tensorflow笔记:第五讲 全连接网络基础(+断点续训)

  • 下载mnist数据集
    1. 首先四个文件下载打包好后直接拖入虚拟机中,下载地址在这里
文件 内容
train-images-idx3-ubyte.gz 训练集图片 - 55000 张 训练图片, 5000 张 验证图片
train-labels-idx1-ubyte.gz 训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz 测试集图片 - 10000 张 图片
t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签

底层的源码将会执行下载、解压、重构图片和标签数据来组成以下的数据集对象:

数据集 目的
data_sets.train 55000 组 图片和标签, 用于训练。
data_sets.validation 5000 组 图片和标签, 用于迭代验证训练的准确性。
data_sets.test 10000 组 图片和标签, 用于最终测试训练的准确性。
  1. 然后需要写一个input_data.py文件
    代码在这里:https://testerhome.com/topics/18906
  2. 最后运行程序就好啦
  • 测试过的代码

mnist_forward.py

#coding:utf-8
#版本信息:ubuntu18.04  python3.6.8  tensorflow 1.14.0
#作者:九除以三还是三哦  如有错误,欢迎评论指正!!
import tensorflow as tf

INPUT_NODE=784   #网络输入节点为784个(代表每张输入图片的像素个数)
OUTPUT_NODE=10   #输出节点为10个(表示输出为数字0-9的十分类
LAYER1_NODE=500  #隐藏层节点500个

def get_weight(shape,regularizer):
    #参数满足截断正态分布,并使用正则化
    w=tf.Variable(tf.truncated_normal(shape,stddev=0.1))
    #将每个参数的正则化损失加到总损失中
    if regularizer !=None:tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b=tf.Variable(tf.zeros(shape))
    return b

def forward(x,regularizer):
    #由输入层到隐藏层的参数w1形状为[784,500]
    w1=get_weight([INPUT_NODE,LAYER1_NODE],regularizer)
    #由输入层到隐藏的偏置b1形状为长度500的一维数组
    b1=get_bias([LAYER1_NODE])
    #前向传播结构第一层为输入 x与参数 w1矩阵相乘加上偏置 b1 ,再经过relu函数 ,得到隐藏层输出 y1。
    y1=tf.nn.relu(tf.matmul(x,w1)+b1)
    #由隐藏层到输出层的参数w2形状为[500,10]
    w2=get_weight([LAYER1_NODE,OUTPUT_NODE],regularizer)
    #由隐藏层到输出的偏置b2形状为长度10的一维数组
    b2=get_bias([OUTPUT_NODE])
    #前向传播结构第二层为隐藏输出 y1与参 数 w2 矩阵相乘加上偏置 矩阵相乘加上偏置 b2,得到输出 y。
    #由于输出 。由于输出 y要经过softmax oftmax 函数,使其符合概率分布,故输出y不经过 relu函数
    y=tf.matmul(y1,w2)+b2
    return y

mnist_backward.py

#coding:utf-8
#版本信息:ubuntu18.04  python3.6.8  tensorflow 1.14.0
#引入tensorflow、input_data、前向传播mnist_forward和os模块
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE=200  #每轮喂入神经网络的图片数
LEARNING_RATE_BASE=0.1  #初始学习率
LEARNING_RATE_DECAY=0.99  #学习率衰减率
REGULARIZER=0.0001  #正则化系数
STEPS=50000   #训练轮数
MOVING_AVERAGE_DECAY=0.99   #滑动平均衰减率
MODEL_SAVE_PATH="./model/"  #模型保存路径
MODEL_NAME="mnist_model"   #模型保存名称
 

def backward(mnist):

    #用placeholder给训练数据x和标签y_占位
    x=tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
    y_=tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
    #调用mnist_forward文件中的前向传播过程forword()函数,并设置正则化,计算训练数据集上的预测结果y
    y=mnist_forward.forward(x,REGULARIZER)
    #当前计算轮数计数器赋值,设定为不可训练类型
    global_step=tf.Variable(0,trainable=False)

    #调用包含所有参数正则化损失的损失函数lossce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
    cem=tf.reduce_mean(ce)
    loss=cem+tf.add_n(tf.get_collection('losses'))

    learning_rate=tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples/BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    #使用梯度衰减算法对模型优化,降低损失函数
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)  
   
 #定义参数的滑动平均   ema=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
    ema_op=ema.apply(tf.trainable_variables())
    #实例化可还原滑动平均的saver 
    #在模型训练时引入滑动平均可以使模型在测试数据上表现的更加健壮
    with tf.control_dependencies([train_step,ema_op]):
        train_op=tf.no_op(name='train')

    saver=tf.train.Saver()

    with tf.Session() as sess:
       #所有参数初始化
        init_op=tf.global_variables_initializer()
        sess.run(init_op)

        #断点续训,加入ckpt操作
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        #每次喂入batch_size组(即200组)训练数据和对应标签,循环迭代steps轮
        for i in range(STEPS):
            xs,ys=mnist.train.next_batch(BATCH_SIZE)
            _,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
            if i%1000==0:
                print("After %d training step(s),loss on training batch is %g."%(step,loss_value))
                saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)


def main():
    #读入mnist
    mnist=input_data.read_data_sets("./data/",one_hot=True)
     #反向传播
    backward(mnist)

if __name__=='__main__':
    main()

mnist_test.py

#coding:utf-8
#版本信息:ubuntu18.04  python3.6.8  tensorflow 1.14.0
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
TEST_INTERVAL_SECS=5#程序5秒的循环间隔时间

def test(mnist):#利用tf.Graph()复现之前定义的计算图
    with tf.Graph().as_default() as g:
        #利用placeholder给训练数据x和标签y_占位
        x=tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
        y_=tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
        #调用mnist_forward文件中的前向传播过程forword()函数
        y=mnist_forward.forward(x,None)

#实例化具有滑动平均的saver对象,从而在会话被加载时模型中的所有参数被赋值为各自的滑动平均值,增强模型的稳定性        ema=tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore=ema.variables_to_restore()
        savor=tf.train.Saver(ema_restore)

        #计算模型在测试集上的准确率
        correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

        while True:#加载指定路径下的ckpt
            with tf.Session() as sess:
                ckpt=tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                #若模型存在,则加载出模型到当前对话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率
                if ckpt and ckpt.model_checkpoint_path:
                    savor.restore(sess,ckpt.model_checkpoint_path)
                    global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
                    print("After %s training step(s),test accuracy=%g"%(global_step,accuracy_score))
                #若模型不存在,则打印出模型不存在的提示,从而test()函数完成
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():
    #加载指定路径下的测试数据集
    mnist=input_data.read_data_sets("./data/",one_hot=True)
    test(mnist)

if __name__=='__main__':
    main()
  • 运行结果

从终端显示的运行结果可以看出,随着训练轮数的增加,网络模型的损失函数值在不断降低,训练集上的精确度也在不断提高,具有良好的泛化能力。


数字识别.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,692评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,482评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,995评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,223评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,245评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,208评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,091评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,929评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,346评论 1 311
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,570评论 2 333
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,739评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,437评论 5 344
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,037评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,677评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,833评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,760评论 2 369
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,647评论 2 354