MNIST手写数字数据集读取方法_mnist怎么读-程序员宅基地

技术标签: 深度学习  MNIST数据读取  TensorFlow  

MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例。
数据集下载网址:http://yann.lecun.com/exdb/mnist/
数据集简介:
1、共有4数据集,下载之后保存在磁盘中(最好放在你代码执行目录下,方便后期使用。)如新建一个文件夹D:*****\MNIST_data存放数据。
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
2、此数据集中,
训练样本:共60000个,其中55000个用于训练,另外5000个用于验证
测试样本:共10000个
3、数据集中像素值
a)使用python读取二进制文件方法读取mnist数据集,则读进来的图像像素值为0-255之间;标签是0-9的数值。
b)采用TensorFlow的封装的函数读取mnist,则读进来的图像像素值为0-1之间;标签是0-1值组成的大小为1*10的行向量。

方法一:使用python的open()和struct.unpack_from()函数操作
【注意:此方法需要将下载的压缩文件解压之后才有使用】
1、首先观察一下mnist的结构,选取train-images为例
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number #文件头魔数
0004 32 bit integer 60000 number of images #图像个数
0008 32 bit integer 28 number of rows #图像宽度
0012 32 bit integer 28 number of columns #图像高度
0016 unsigned byte ?? pixel #图像像素值
0017 unsigned byte ?? pixel
……..
xxxx unsigned byte ?? pixel

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
……..
xxxx unsigned byte ?? label
The labels values are 0 to 9.
2、读取流程如下:
这里写图片描述

3、具体代码如下:

import numpy as np
import struct
import matplotlib.pyplot as plt

# 训练集文件
train_images_idx3_ubyte_file = 'MNIST_data/train-images.idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = 'MNIST_data/train-labels.idx1-ubyte'

# 测试集文件
test_images_idx3_ubyte_file = 'MNIST_data/t10k-images.idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = 'MNIST_data/t10k-labels.idx1-ubyte'


def decode_idx3_ubyte(idx3_ubyte_file):
    """
    解析idx3文件的通用函数
    :param idx3_ubyte_file: idx3文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx3_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
    offset = 0
    fmt_header = '>iiii' #因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # 解析数据集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)  #获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    print(offset)
    fmt_image = '>' + str(image_size) + 'B'  #图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    print(fmt_image,offset,struct.calcsize(fmt_image))
    images = np.empty((num_images, num_rows, num_cols))
    #plt.figure()
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
            print(offset)
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
        #print(images[i])
        offset += struct.calcsize(fmt_image)
#        plt.imshow(images[i],'gray')
#        plt.pause(0.00001)
#        plt.show()
    #plt.show()

    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    解析idx1文件的通用函数
    :param idx1_ubyte_file: idx1文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx1_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数和标签数
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print ('已解析 %d' % (i + 1) + '张')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """
    TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  60000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """
    TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.

    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """
    TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  10000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """
    TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  10000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.

    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)



if __name__ == '__main__':
    train_images = load_train_images()

    train_labels = load_train_labels()
    # test_images = load_test_images()
    # test_labels = load_test_labels()

    # 查看前十个数据及其标签以读取是否正确
    for i in range(10):
        print(train_labels[i])
        plt.imshow(train_images[i], cmap='gray')
        plt.pause(0.000001)
        plt.show()
    print('done')

方法二:使用TensorFlow封装代码读取
【注意:此方法,对下载的数据集压缩包不需要解压,代码会自己解压。】
TensorFlow的封装让使用MNIST数据集变得更加方便。MNIST数据集是NIST数据集的一个子集,它包含了60000张图片作为训练数据,10000张图片作为测试数据。在MNIST数据集中的每一张图片都代表了0~9中的一个数字。图片的大小都为28*28,且数字都会出现在图片的正中间。
这里写图片描述

具体读取代码如下:

import tensorflow as tf
import matplotlib.pyplot as plt

''' 读取MNIST数据方法一'''
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
'''1)获得数据集的个数'''
train_nums = mnist.train.num_examples
validation_nums = mnist.validation.num_examples
test_nums = mnist.test.num_examples
print('MNIST数据集的个数')
print(' >>>train_nums=%d' % train_nums,'\n',
      '>>>validation_nums=%d'% validation_nums,'\n',
      '>>>test_nums=%d' % test_nums,'\n')

'''2)获得数据值'''
train_data = mnist.train.images   #所有训练数据
val_data = mnist.validation.images  #(5000,784)
test_data = mnist.test.images       #(10000,784)
print('>>>训练集数据大小:',train_data.shape,'\n',
      '>>>一副图像的大小:',train_data[0].shape)
'''3)获取标签值label=[0,0,...,0,1],是一个1*10的向量'''
train_labels = mnist.train.labels     #(55000,10)
val_labels = mnist.validation.labels  #(5000,10)
test_labels = mnist.test.labels       #(10000,10)

print('>>>训练集标签数组大小:',train_labels.shape,'\n',
      '>>>一副图像的标签大小:',train_labels[1].shape,'\n',
      '>>>一副图像的标签值:',train_labels[0])

'''4)批量获取数据和标签【使用next_batch(batch_size)】'''
batch_size = 100    #每次批量训练100幅图像
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
print('使用mnist.train.next_batch(batch_size)批量读取样本\n')
print('>>>批量读取100个样本:数据集大小=',batch_xs.shape,'\n',
      '>>>批量读取100个样本:标签集大小=',batch_ys.shape)
#xs是图像数据(100,784);ys是标签(100,10)

'''5)显示图像'''
plt.figure()
for i in range(100):
    im = train_data[i].reshape(28,28)
    im = batch_xs[i].reshape(28,28)
    plt.imshow(im,'gray')
    plt.pause(0.0000001)
plt.show()

显示结果:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
MNIST数据集的个数
 >>>train_nums=55000 
 >>>validation_nums=5000 
 >>>test_nums=10000 
>>>训练集数据大小: (55000, 784) 
 >>>一副图像的大小: (784,)
>>>训练集标签数组大小: (55000, 10) 
 >>>一副图像的标签大小: (10,) 
 >>>一副图像的标签值: [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
使用mnist.train.next_batch(batch_size)批量读取样本
>>>批量读取100个样本:数据集大小= (100, 784) 
 >>>批量读取100个样本:标签集大小= (100, 10)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/panrenlong/article/details/81736754

智能推荐

c语言数组的排序递归,C语言:用递归的方式对数组排序-程序员宅基地

文章浏览阅读1.1k次。原题如下:编写程序,要求用户录入一串整数(把这串整数存储在数组中),然后通过调用selection_sort函数来排列这些整数。在给定n个元素的数组后,election_sort函数必须做下列工作:搜索数组找出最大的元素,然后把它移到数组的最后面;递归的调用函数本身对前面的n-1个数组元素进行排序。下面是我自己写的程序:(运行结果不对,求大神指导!!!)#include #define N 8in..._c语音递归n从高到低排序

ElasticSearch亿级数据毫秒查询实现_亿级数据模糊查询用什么数据库-程序员宅基地

文章浏览阅读6.1k次,点赞5次,收藏12次。面临问题:很多时候数据量大了,特别是有几亿条数据的时候,可能你会发现,跑个搜索怎么一下 5~10s。第一次搜索的时候,是 5~10s,后面反而就快了,可能就几百毫秒。说实话,ES 性能优化是不可能随手调一个参数,就可以万能的应对所有的性能慢的场景。也许有的场景是你换个参数,或者调整一下语法,就可以搞定,但是绝对不是所有场景都可以这样。性能优化:Filesystem Cache你往 ES..._亿级数据模糊查询用什么数据库

Python量化交易学习笔记(35)——backtrader多股回测避坑2_indexerror: array assignment index out of range-程序员宅基地

文章浏览阅读6.8k次。本文继续记录多股回测时遇到的异常情况。坑描述backtrader在读取日线数据时,会自动给date数据添加“时:分:秒.毫秒(23:59:59.999990)”信息。而通常用户在指定回测周期的开始和结束日期时,只会精确到日,时分秒信息会被backtrader默认以0补全。由于上述两个事实的存在,假如用户指定回测周期的结束日期有日线数据(由于非交易日、停盘等原因,可能没有日线数据),那么在backtrader中,回测周期的结束时间就会被设定为该日的00:00:00,而backtrader读_indexerror: array assignment index out of range

检索之 乘积量化(Product Quantization)_乘积量化检索-程序员宅基地

文章浏览阅读1.9k次。本文转载自:https://www.cnblogs.com/mafuqiang/p/7161592.html乘积量化1。简介  乘积量化(PQ)算法是和VLAD算法是由法国INRIA实验室一同提出来的,为的是加快图像的检索速度,所以它是一种检索算法,在矢量量化(Vector Quantization,VQ)的基础上发展而来,虽然PQ不算是新算法,但是这种思想还是挺有用处的,本文没有添加公式。  它..._乘积量化检索

样本熵(Python实现)_样本熵电池代码-程序员宅基地

文章浏览阅读6.2k次,点赞8次,收藏39次。1. 基本概念1.1 熵熵原本是一个热力学概念,是用来描述热力学系统混乱(无序)程度的度量。在信息论建立之后,关于上的概念和理论得到了发展。作为衡量时间序列中新信息发生率的非线性动力学参数,熵在众多的科学领域得到了应用。八十年代最常用的熵的算法是K-S熵及由它发展来的E-R熵,但这两种熵的计算即使对于维数很低的混沌系统也需要上万点的数据,而且它们对于噪声很敏感,时间序列叠加了随机噪声后这两种熵的计算可能不收敛。1.2 近似熵近似熵(APEN, Aproximate Entropy),是由Pincus_样本熵电池代码

BeanFactory和FactoryBean以及ApplicationContext的区别_beanfactory与factorybean和applicationcontext的区别-程序员宅基地

文章浏览阅读932次。BeanFactoryBeanFactory是IOC最基本的容器,负责生产和管理bean,它在为其他具体的IOC容器提供了最基本的规范,例如XmlBeanFactory、ApplicationContext等具体的挺起都实现了BeanFactory,再在其基础上附加了其他功能BeanFactory源码package org.springframework.beans.factory; ..._beanfactory与factorybean和applicationcontext的区别

随便推点

「Arm Arch」 虚拟化微架构_hypervisor arm-程序员宅基地

文章浏览阅读280次。全文3000字,预计阅读时长:8分钟适用于从事ARM软硬件设计、开发、调试的工程师、教师以及学生对于大部分开发者来讲,ARM架构知识一直存放于盲盒之中,知之甚少;而ARM架构知识是ARM结构化知识中非常关键的一部分,它的缺失,会导致我们对于问题的系统化思考难以进行。所以增设了《ARM架构知多少-A系列》专栏来和大家一起学习ARM架构,完善知识结构,拓展系统思考边界。_hypervisor arm

【Xilinx】基于DMA的adc读取_adc数据长度为什么是1920-程序员宅基地

文章浏览阅读1.1k次。硬件环境:ZYNQ7000软件环境:petalinux2018.2 xilinx_vivado_sdk2018.2学习例程:1、DMA初始化1)定义变量//定义ioctrl的命令#define AXI_ADC_IOCTL_BASE 'W'#define AXI_ADC_SET_SAMPLE_NUM _IO(AXI_ADC_IOCTL_BASE, 0)#define AXI_ADC_SET_DMA_LEN_..._adc数据长度为什么是1920

审计日志在分布式系统中的应用_审计日志定义-程序员宅基地

文章浏览阅读5.4k次,点赞3次,收藏9次。前言分布式系统的执行环境往往是异常复杂的,很多情况涉及到多节点间的消息通信。相比较于单节点系统而言,分布式系统在问题追踪,排查方面显然也复杂很多。那么这个时候,在分布式系统中,增加哪些类型的日志数据,来帮助我们发现和定位问题呢?答案就是我们今天将要阐述的审计日志(Audit log)。审计日志的概念很多人可能在想这样一个问题:同样是日志,审计日志和普通的日志,区别在于哪里呢?审计日志,..._审计日志定义

conda虚拟环境总结与解读_conda 环境-程序员宅基地

文章浏览阅读7k次,点赞19次,收藏54次。csdn上有很多关于conda的文章,但是一直没有一个宏观一些的文章,我将从宏观角度出发,对文章进行一个整合,解读,将新同学从conda环境入门到配置,应用全流程进行解读。当然,这篇文章因为是宏观一些,可能很多同学不能一次性看懂,没关系,这个可以反复看,在不同阶段都可以提供帮助。......_conda 环境

HashMap与TreeMap的排序以及四种遍历方式_treehash的遍历方法-程序员宅基地

文章浏览阅读593次。一、Map概述1、Map是将键映射到值( key-value )的对象。一个映射不能包含重复的键;每个键最多只能映射到一个值。2、Map与Collection的区别(1)Map 是以键值对的方式存储元素,键唯一,值可以重复。(2)Collection存储的是单列元素,子接口Set元素唯一,子接口List可以重复。(3)Map的数据结构针对键有效,跟值无关..._treehash的遍历方法

K8S调度机制和Pod基本故障排查_pod崩溃一般怎么排查-程序员宅基地

文章浏览阅读704次。目录一、调度约束过程解析1.1、调度方式1.2、示例1 nodeName1.3、示例2 nodeSelector二、故障排除2.1、故障现象2.2、排查思路一、调度约束过程解析1、首先,用户可以通过kubectl命令或者dashborad、API调用的方式(用作开发)来创建资源,和管理资源(Kubernetes通过watch的机制进行每个组件的协作,每个组件之间的设计实现了解耦)2、用户提交创建资源的请求给API Server,API Server将创建资源的元信息(属性信息)写入到etcd中,et_pod崩溃一般怎么排查