使用Tensorflow进行Mnist数据集实战

MNIST数据集简介

在进行MNIST数据集实战之前,我们先简单了解一下MNIST数据集。
MNIST数据集的全称叫手写数字图片数据集,它包含了0~9共10种数字的手写图片,每种数字一共有7000张图片,采集自不同风格的真实手写图片,整个数据集有70000张图片。

其中60000张图片用作训练集,剩下的10000张图片用作测试集。

而其中每张图片的大小是28x28大小,并且只保留了灰度信息。

我们学写程序,都会从一个“hello world”开始,那么,我的深度学习之旅就从这个MNIST数据集开始吧。

下面具体说一下怎么使用这个数据集来训练模型。

第一步: 首先加载MINST数据集,现在这个过程在Tensorflow 2.0中很简单了,一行代码就可搞定。

每二步: 使用训练集来训练模型,每训练100次,打印看一下模型的损失值。以及用测试集来测试一下模型的性能。

第三步: 在训练完成后,把模型训练过程中的损失值和准确率使用Matplotlib画出来,使我们能够更直观地看到模型的训练曲线。

下面看一下具体代码实现流程:

首先我们要导入一些库,如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers
import matplotlib
from matplotlib import pyplot as plt

接下来,我们把MNIST数据集加载进来:

(x_train,y_train),(x_test,y_test) = datasets.mnist.load_data()
print('x_train:',x_train.shape,'y_train:',y_train.shape,'x_test:',x_test.shape,'y_test:',y_test.shape)
x_train: (60000, 28, 28) y_train: (60000,) x_test: (10000, 28, 28) y_test: (10000,)

可以看到,加载这个数据集只需使用一行代码,另外,我打印了这个数据集中训练集和测试集的一些信息。

下面,把训练集和测试集做成一个Tensorflow易于批量处理的Dataset:

batchsz = 512#批量处理大小,表示一次性可以处理512张图片

train_db = tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_db = train_db.shuffle(1000)##把训练集中的数据打散,防止神经网络记忆住训练数据
train_db = train_db.batch(batchsz)#批量化处理
train_db = train_db.map(preprocess)#对训练数据集中的数据做预处理
train_db = train_db.repeat(20)#对整个训练数据集遍历20次进行训练

注意,其中有一个预处理函数preprocess,它的实现如下:

def preprocess(x,y):
    print(x.shape,y.shape)
    x = tf.cast(x,dtype=tf.float32)/255.#把图片中的灰度值归一化到0~1区间
    x = tf.reshape(x,[-1,28*28])#把每一张图片都打平成28*28的一维向量,方便神经网络模型处理
    y = tf.cast(y,dtype=tf.int32)#把训练集中的标签值类型转化为tf.int32类型
    y = tf.one_hot(y,depth=10)#然后把标签值处理为one_hot编码格式的数据
    
    return x,y

训练集处理好后,我们按着同样的流程处理一下测试集:

test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.batch(batchsz)
test_db = test_db.map(preprocess)

#查看批量处理后,训练集中的数据
sample_x,sample_y = next(iter(train_db))
print(sample_x.shape,sample_y.shape)
(512, 784) (512, 10)

通过打印信息,可以看到,批量处理后,训练集中的一批数据中有512张图片,每张图片被打平成784(28x28)和一维向量。同时标签数据中有512个标签,每个标签都是长度为10的向量。

准备工作都好了,开始训练模型吧。我这里使用张量的方式来实现这个模型。

def main():
    lr = 1e-2#学习率
    losses,accs = [],[]#用来保存损失值和准确率的数组
    
    #初始化模型需要用到的超参数
    w1,b1 = tf.Variable(tf.random.normal([784,256],stddev=0.1)), tf.Variable(tf.zeros([256]))
    w2,b2 = tf.Variable(tf.random.normal([256,128],stddev=0.1)), tf.Variable(tf.zeros([128]))
    w3,b3 = tf.Variable(tf.random.normal([128,10],stddev=0.1)),tf.Variable(tf.zeros([10]))

    #使用一个for循环进行迭代20次
    for step, (x,y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            #下面这个是用张量实现的一个简单的全连接层
            h1 = x@w1+b1
            h1 = tf.nn.relu(h1)#激活函数采用relu函数

            h2 = h1@w2+b2
            h2 = tf.nn.relu(h2)

            out = h2@w3+b3
    
            loss = tf.square(y-out)#这里是每一张图片的损失值
            loss = tf.reduce_mean(loss)#计算这批样本的损失值的均值
            
        grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])#根据损失值计算出各个超参数的梯度
        for p,g in zip([w1,b1,w2,b2,w3,b3],grads):
            p.assign_sub(g*lr)#对超参数进行梯度下降
    
        #训练100步后打印损失值
        if step%100 ==0:
            print(step,'loss:',float(loss))
            losses.append(float(loss))
            
        if step%100 ==0:
            #每训练100次后,测试模型
            total,total_correct = 0.,0.
            for x_test,y_test in test_db:
                h1 = x_test@w1+b1
                h1 = tf.nn.relu(h1)
                
                h2 = h1@w2+b2
                h2 = tf.nn.relu(h2)
                
                out = h2@w3+b3
                
                pred = tf.argmax(out,axis=1)#测试集的预测值
                y_test = tf.argmax(y_test,axis=1)#测试集的真实标签
                
                correct = tf.equal(y_test,pred)#计算预测正确的数据
                total_correct += tf.reduce_sum(tf.cast(correct,dtype=tf.int32)).numpy()
                total += y_test.shape[0]
                
            print(step, 'Evaluate acc:', total_correct/total)
            accs.append(total_correct/total)
    
    #画模型的训练误差曲线
    plt.figure()
    x = [i*80 for i in range(len(losses))]
    plt.plot(x,losses,color='C0',marker='s',label='训练')
    plt.xlabel('Step')
    plt.ylabel('MSE')
    plt.legend()
    plt.savefig('train.svg')
    
    #画模型的测试准确率曲线
    plt.figure()
    plt.plot(x,accs,color='C1',marker='s',label='测试')
    plt.xlabel('Step')
    plt.ylabel('准确率')
    plt.legend()
    plt.savefig('test.svg')

最后,在这个函数调用main():

if __name__ == '__main__':
    main()

下面是运行截图:


模型的训练误差曲线
模型的测试准确率曲线
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,383评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,522评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,852评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,621评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,741评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,929评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,076评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,803评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,265评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,582评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,716评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,395评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,039评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,027评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,488评论 2 361
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,612评论 2 350