mnist-数据集的学习-3.6手写数字识别 # coding: utf-8# 指定文件编码为utf-8防止中文注释/中文打印乱码# 兼容Python版本尝试导入网络请求库try:# Python3内置网络下载工具用于远程下载MNIST数据集文件import urllib.request# 如果导入失败说明是Python2环境主动抛出错误提示用户except ImportError:raise ImportError(You should use Python 3.x)# 绘图库展示手写数字图片import matplotlib.pyplot as plt# 随机库随机抽取样本图片import random# 处理文件路径相关工具import os.path# 解压gz压缩包MNIST原始数据是.gz格式import gzip# 序列化/反序列化用来把数据集保存为pkl缓存文件加速下次读取import pickle# 文件系统操作创建路径、判断文件是否存在等import os# 数值计算核心库存储、处理图像像素数据import numpy as np# -------------------------- 全局字体配置彻底解决中文方框乱码 --------------------------plt.rcParams[font.sans-serif] [SimHei] # Windows黑体支持中文plt.rcParams[axes.unicode_minus] False # 解决图像负号显示异常# MNIST官方数据集亚马逊镜像下载地址url_base https://ossci-datasets.s3.amazonaws.com/mnist/ # mirror site# 字典存储4个MNIST数据文件的文件名key_file {train_img:train-images-idx3-ubyte.gz, # 训练集图片压缩包train_label:train-labels-idx1-ubyte.gz, # 训练集标签压缩包test_img:t10k-images-idx3-ubyte.gz, # 测试集图片压缩包test_label:t10k-labels-idx1-ubyte.gz # 测试集标签压缩包}# 获取当前mnist.py脚本所在的文件夹绝对路径dataset_dir os.path.dirname(os.path.abspath(__file__))# 拼接缓存文件完整路径脚本目录下的mnist.pklsave_file dataset_dir /mnist.pkl# 训练样本总数6万张train_num 60000# 测试样本总数1万张test_num 10000# 单张图片原始维度(通道数1, 高28像素, 宽28像素)img_dim (1, 28, 28)# 单张图片展平后一维长度28*28784img_size 784def _download(file_name):私有函数下载单个gz数据集文件已存在则跳过# 拼接文件完整本地路径脚本目录 文件名file_path dataset_dir / file_name# 判断本地是否已有该文件存在直接返回不用重复下载if os.path.exists(file_path):return# 打印日志提示正在下载哪个文件print(Downloading file_name ... )# 远程下载文件第一个参数下载链接第二个参数本地保存路径urllib.request.urlretrieve(url_base file_name, file_path)# 下载完成提示print(Done)def download_mnist():公有调用函数循环下载key_file里全部4个gz文件# 遍历字典里4个文件名逐个调用下载函数for v in key_file.values():_download(v)def _load_label(file_name):私有函数读取标签gz文件转换为numpy数组返回# 拼接标签文件本地完整路径file_path dataset_dir / file_name# 打印转换日志print(Converting file_name to NumPy Array ...)# 以二进制只读模式打开gz压缩文件with gzip.open(file_path, rb) as f:# 读取全部二进制数据转uint8无符号整数offset8跳过文件头部8字节描述信息labels np.frombuffer(f.read(), np.uint8, offset8)# 转换完成提示print(Done)# 返回全部标签数组形状(样本数,)数值0~9return labelsdef _load_img(file_name):私有函数读取图片gz文件转换为展平后的numpy数组返回# 拼接图片文件本地完整路径file_path dataset_dir / file_name# 打印转换日志print(Converting file_name to NumPy Array ...)# 二进制只读打开gz图片压缩包with gzip.open(file_path, rb) as f:# 读取二进制像素跳过头部16字节文件描述信息转uint8像素值(0~255)data np.frombuffer(f.read(), np.uint8, offset16)# 将一维像素数组重塑(样本数量, 784)每张图片展平为784维向量data data.reshape(-1, img_size)# 转换完成提示print(Done)# 返回展平后的图片数组形状(样本数,784)return datadef _convert_numpy():私有函数整合4个文件组装完整数据集字典# 定义空字典存放全部图像、标签数据dataset {}# 读取训练图片存入字典keydataset[train_img] _load_img(key_file[train_img])# 读取训练标签存入字典keydataset[train_label] _load_label(key_file[train_label])# 读取测试图片存入字典keydataset[test_img] _load_img(key_file[test_img])# 读取测试标签存入字典keydataset[test_label] _load_label(key_file[test_label])# 返回完整数据集字典return datasetdef init_mnist():初始化MNIST下载文件 转numpy 保存本地pkl缓存# 第一步下载全部4个gz原始文件download_mnist()# 第二步把gz文件转成numpy数组得到完整数据集dataset _convert_numpy()# 提示开始生成缓存文件print(Creating pickle file ...)# 二进制写入模式打开缓存文件with open(save_file, wb) as f:# 将数据集序列化存入pkl-1使用最高压缩协议pickle.dump(dataset, f, -1)# 缓存生成完成提示print(Done!)def _change_one_hot_label(X):私有函数普通数字标签转独热编码标签输入一维数组 [5,0,3...]输出二维数组 [[0,0,0,0,0,1,0,0,0,0], [1,0,0...], ...]# 创建全0矩阵行数样本总数列数数字类别100-9T np.zeros((X.size, 10))# 遍历每一个样本的索引和对应标签for idx, row in enumerate(T):# 将该行对应数字位置设为1其余保持0完成独热编码row[X[idx]] 1# 返回独热编码二维标签数组return Tdef load_mnist(normalizeTrue, flattenTrue, one_hot_labelFalse):对外主接口函数读取MNIST数据集优先读取本地缓存无缓存自动初始化下载Parameters----------normalize : boolTrue像素值从0~255归一化为0.0~1.0浮点数False保留0~255整数flatten : boolTrue每张图片展平为一维784向量False保留四维形状(样本数,1,28,28)灰度图one_hot_label : boolTrue标签转为10维独热编码False直接返回0~9数字Returns-------(训练图像, 训练标签), (测试图像, 测试标签)# 判断缓存文件mnist.pkl是否存在不存在则执行全套下载转换初始化if not os.path.exists(save_file):init_mnist()# 二进制读取缓存文件加载序列化数据集with open(save_file, rb) as f:dataset pickle.load(f)# 归一化分支处理if normalize:# 遍历训练、测试图片两个keyfor key in (train_img, test_img):# 将uint8像素转为32位浮点型dataset[key] dataset[key].astype(np.float32)# 全部像素除以255映射到0~1区间dataset[key] / 255.0# 独热编码标签分支处理if one_hot_label:# 训练标签转独热dataset[train_label] _change_one_hot_label(dataset[train_label])# 测试标签转独热dataset[test_label] _change_one_hot_label(dataset[test_label])# 不展平图片分支恢复原图四维结构if not flatten:# 遍历训练、测试图片for key in (train_img, test_img):# 重塑维度(样本数量, 通道1, 高28, 宽28)dataset[key] dataset[key].reshape(-1, 1, 28, 28)# 返回 ((训练图,训练标签), (测试图,测试标签)) 二元组return (dataset[train_img], dataset[train_label]), (dataset[test_img], dataset[test_label])# -------------------------- 简易神经网络工具函数 --------------------------def sigmoid(x):激活函数sigmoidreturn 1 / (1 np.exp(-x))def simple_two_layer_net_predict(x, W1, b1, W2, b2):两层全连接网络前向传播预测a1 np.dot(x, W1) b1z1 sigmoid(a1)a2 np.dot(z1, W2) b2return a2# 方式1直接在本文件运行测试 # 当前脚本直接执行时才运行下方测试代码被其他文件import导入时不执行if __name__ __main__:print( 测试1默认参数加载数据集 )# 调用主函数默认配置归一化展平一维普通数字标签(x_train, t_train), (x_test, t_test) load_mnist()# 打印训练图片数组维度60000张每张784像素print(训练图片 shape:, x_train.shape)# 打印训练标签数组维度60000个数字标签print(训练标签 shape:, t_train.shape)# 打印测试图片数组维度10000张每张784像素print(测试图片 shape:, x_test.shape)# 打印测试标签数组维度10000个数字标签print(测试标签 shape:, t_test.shape)# 打印第一张训练图片对应的数字标签print(第一张训练图片标签数字, t_train[0])print(\n 测试2自定义参数加载 )# 自定义参数不展平图片、标签转为独热编码(x_train_2, t_train_2), (x_test_2, t_test_2) load_mnist(flattenFalse, one_hot_labelTrue)# 打印未展平图片维度(60000, 1通道, 28高, 28宽)print(四维原图训练集 shape:, x_train_2.shape)# 打印独热标签维度60000行每行10个0/1print(独热编码训练标签 shape:, t_train_2.shape)# 打印第一张图片的独热编码数组print(第一张图片独热标签, t_train_2[0])# 从独热编码还原真实数字print(独热标签对应数字, np.argmax(t_train_2[0]))print(\n 测试3可视化手写数字图片 )# 1. 展示第一张图片img x_train[0].reshape(28,28)plt.figure(figsize(4,4))plt.imshow(img, cmapgray)plt.title(f手写数字{t_train[0]})plt.show()# 2. 随机抽取5张图片批量展示print(随机抽取5张手写数字样本)for _ in range(5):idx random.randint(0, len(x_train)-1)img x_train[idx].reshape(28, 28)label t_train[idx]plt.figure(figsize(3,3))plt.imshow(img, cmapgray)plt.title(f手写数字{label})plt.show()print(\n 测试4简易两层神经网络预测演示 )# 初始化两层网络权重与偏置W1 np.random.randn(784, 50)b1 np.zeros(50)W2 np.random.randn(50, 10)b2 np.zeros(10)# 对第一张图片做预测pred_out simple_two_layer_net_predict(x_train[0], W1, b1, W2, b2)pred_num np.argmax(pred_out)real_num t_train[0]print(f网络预测数字{pred_num}图片真实数字{real_num})print(权重随机初始化预测结果不准确后续训练后可提升识别准确率)