TensorFlow之使用RNN处理Mnist数据集

关于tf.nn.dynamic_rnn

tf.nn.dynamic_rnn 函数是tensorflow封装的用来实现递归神经网络(RNN)的函数。
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)

重要参数介绍:
  1. cell:LSTM、GRU等的记忆单元。cell参数代表一个LSTM或GRU的记忆单元,也就是一个cell。例如,cell = tf.nn.rnn_cell.LSTMCell((num_units),其中,num_units表示rnn cell中神经元个数,也就是下文的cell.output_size。返回一个LSTM或GRU cell,作为参数传入。
  2. inputs:输入的训练或测试数据,一般格式为[batch_size, max_time, embed_size],其中batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,embed_size表示嵌入的词向量的维度。
  3. sequence_length:是一个list,假设你输入了三句话,且三句话的长度分别是5,10,25,那么sequence_length=[5,10,25]。
  4. initial_state: (可选)RNN的初始state(状态)。如果cell.state_size(一层的RNNCell)是一个整数,那么它必须是一个具有适当类型和形状的张量[batch_size,cell.state_size]。如果cell.state_size是一个元组(多层的RNNCell,如MultiRNNCell),那么它应该是一个张量元组,每个元素的形状为[batch_size,s] for s in cell.state_size
  5. time_major:决定了输出tensor的格式,如果为True, 张量的形状必须为 [max_time, batch_size,cell.output_size]。如果为False, tensor的形状必须为[batch_size, max_time, cell.output_size],cell.output_size表示rnn cell中神经元个数。
返回值:元组(outputs, states)
  1. outputs:RNN的最后一层的输出,是一个tensor
    如果为time_major== False,则shape [batch_size,max_time,cell.output_size]。如果为time_major== True,则shape: [max_time,batch_size,cell.output_size]。
  2. state: RNN最后时间步的state,如果cell.state_size是一个整数(一般是单层的RNNCell),则state的shape:[batch_size,cell.state_size]。如果它是一个元组(一般这里是 多层的RNNCell),那么它将是一个具有相应形状的元组。注意:如果若RNNCell是 LSTMCells,则state将为每层cell的LSTMStateTuple的元组Tuple(LSTMStateTuple,LSTMStateTuple,LSTMStateTuple)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow .examples .tutorials .mnist import input_data
mnist=input_data .read_data_sets ('data/',one_hot= True)
trainimgs, trainlabels, testimgs, testlabels \
 = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
 = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")

training_epochs = 10
batch_size      = 32
display_step    = 1

diminput=28
dimhidden=128
dimoutput=nclasses
nsteps=28#把输入拆分为28个小部分

weights={
    'hidden':tf.Variable(tf.random_normal([diminput ,dimhidden ])),
    'out':tf.Variable(tf.random_normal([dimhidden,dimoutput]))
}
biases={
    'hidden':tf.Variable(tf.random_normal([dimhidden])),
    'out':tf.Variable(tf.random_normal([dimoutput]))
}

def _RNN(_X,_W,_b,_nsteps,):
    #_X的格式为:[bitchsize,h,w]
    #_X=tf.transpose(_X,[1,0,2])#将第一维度和第二维度交换-无用
    '''
    tf.transpose(input, [dimension_1, dimenaion_2,..,dimension_n])
    这个函数主要适用于交换输入张量的不同维度用的,如果输入张量是二维,就相当是转置。
    dimension_n是整数,如果张量是三维,就是用0,1,2来表示。这个列表里的每个数对应相应的维度。
    如果是[2,1,0],就把输入张量的第三维度和第一维度交换。
    '''
    _X=tf.reshape(_X,[-1,diminput])#(batchsize*28,28)
    _H=tf.matmul(_X,_W['hidden'])+_b['hidden']#(batchsize*28,128)
    _H=tf.reshape(_H,[-1,_nsteps,dimhidden])#(batchsize,28,128)
    #算隐层时,是按照整体去算,算完之后因为要进行RNN训练——分切片进行
   # _Hsplit=tf.split(_H,_nsteps,0)
    ''''
    tf.split(dimension, num_split, input):
    dimension的意思就是输入张量的哪一个维度,如果是0就表示对第0维度进行切割。
    num_split就是切割的数量,如果是2就表示输入张量被切成2份,每一份是一个列表。
    '''
    #输入数据的格式的问题
    lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias=1.0,state_is_tuple=True)#重点
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    # dynamic_rnn 接收张量(batch, steps, inputs)或者(steps, batch, inputs)作为X_in
    outputs, final_state=tf.nn.dynamic_rnn(lstm_cell,_H ,initial_state=init_state,time_major=False)#重点

    '''
    tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True): 
    n_hidden表示神经元的个数,
    forget_bias就是LSTM门的忘记系数,如果等于1,就是不会忘记任何信息。如果等于0,就都忘记。
    state_is_tuple默认就是True,官方建议用True,就是表示返回的状态用一个元祖表示。
    '''
    _O=tf.matmul(final_state[1],_W['out'])+_b['out']
    return{
        'X': _X, 'H': _H, 'O': _O
    }
print('NEtwork ready')

learning_rate = 0.001
x      = tf.placeholder("float", [None, nsteps, diminput])
y      = tf.placeholder("float", [None, dimoutput])
myrnn  = _RNN(x, weights, biases, nsteps)
pred   = myrnn['O']
cost   = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optm   = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
accr   = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))

init   = tf.global_variables_initializer()


sess = tf.Session()
sess.run(init)
print ("Start optimization")
for epoch in range(training_epochs):
    avg_cost = 0.
    #total_batch = int(mnist.train.num_examples/batch_size)
    total_batch = 100
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
        # Fit training using batch data
        feeds = {x: batch_xs, y: batch_ys}
        sess.run(optm, feed_dict=feeds)
        # Compute average loss
        avg_cost += sess.run(cost, feed_dict=feeds) / total_batch
        # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
        feeds = {x: batch_xs, y: batch_ys}
        train_acc = sess.run(accr, feed_dict=feeds)
        print(" Training accuracy: %.3f" % (train_acc))
        '''
        testimgs = testimgs.reshape((ntest, nsteps, diminput))
        feeds = {x: testimgs, y: testlabels}
        test_acc = sess.run(accr, feed_dict=feeds)
        print(" Test accuracy: %.3f" % (test_acc))
        '''
    print("Optimization Finished.")
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,126评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,254评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,445评论 0 341
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,185评论 1 278
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,178评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,970评论 1 284
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,276评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,927评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,400评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,883评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,997评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,646评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,213评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,204评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,423评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,423评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,722评论 2 345

推荐阅读更多精彩内容