TensorFlow从0到1 - 15 - 重新思考神经网络初始化

TensorFlow从0到1系列回顾

上一篇14 交叉熵损失函数——克服学习缓慢从最优化算法层面入手,将二次的均方误差(MSE)更换为交叉熵作为损失函数,避免了当出现“严重错误”时导致的学习缓慢。

本篇引入1/sqrt(nin)权重初始化方法,从另一个层面——参数初始化(神经网络调教的5个层面归纳在13 AI驯兽师:神经网络调教综述)入手改善网络的学习速度。

相比之前采用的标准正态分布初始化,1/sqrt(nin)权重初始化不仅明显的加快了学习速度,而且单纯性(其他任何参数不变)的提升了测试集识别精度1~2个百分点。

理解了1/sqrt(nin)权重初始化的思想,就能很容易的理解Xavier、He权重初始化方法。

参数初始化之“重”

神经网络的训练过程,就是自动调整网络中参数的过程。在训练的起初,网络的参数总要从某一状态开始,而这个初始状态的设定,就是神经网络的初始化。

之所以要重新思考神经网络权重和偏置的初始化,是因为它对于后续的训练非常重要。

12 TensorFlow构建3层NN玩转MNIST中就踩了“参数初始化的坑”:简单将权重和偏置初始化为0,导致了网络训练陷入了一个局部最优沼泽而无法自拔,最终识别率仅为60%。

不仅有“局部最优”的坑,在14 交叉熵损失函数——防止学习缓慢还见识了初始化导致“神经元饱和”的坑。

合适网络初始值,不仅有助于梯度下降法在一个好的“起点”上去寻找最优值,还能避免神经元发生学习饱和

重新审视标准正态分布

Initialization

在之前实现的MNIST数字识别案例中,权重和偏置的初始化采用的是符合均值为0、标准差为1的标准正态分布(Standard Noraml Distribution)随机化方法。基于它的训练过程还算平稳。但它是最佳的初始化策略吗?

它如此“特别”更像是一个警告:我们并不总能轻易的得到最佳答案,一定还有“压榨”的空间。

一个寻找切入点的常用方法,就是人为诱导其产生问题。让一个具有1000个神经元输入层的网络,以标准正态分布做随机初始化,然后人造干预:令输入层神经元一半(500个)值为1,另一半(另500个)值为0。现在聚焦到接下来隐藏层中的一个神经元:

隐藏层神经元

如上图所示,1000个输入层神经元全部连接到了隐藏层的第一个神经元。此时考察神经元的加权和z = ∑jwjxj + b:

  • 将z的表达式展开,初始共有1001项(不要漏掉偏置b);
  • 人为令输入xj中的500个为0,所以z的表达式最终有501项;
  • 人为令输入xj的其余500个为1,所以z由500项wj和1项b组成,它们符合标准正态分布N(0,1);
  • 推导得到z符合均值为0,标准差为√501(501的平方根)的正态分布,推导过程稍后解释;

通过人为设置特殊的输入,由权重w和偏置b的统计分布,得到了z的统计分布:

z的分布

从图中可见,由于标准差√501非常大,导致z的分布从-30到30出现的比例都很高,也就是说,∣z∣ >> 1出现的概率非常大。还记得Sigmoid曲线吗?当∣z∣ >> 1时,σ'(z)就会变得非常小,神经元学习饱和。

Sigmoid

类似的,网络中后续层中的神经元也有同样的性质。

虽然是人为制造特殊的输入数据暴露了网络的问题,但是从中可以得到一个启示:如果网络的权重和偏置采用N(0,1)初始化,那么网络中各层的神经元数量n越多,造成后续层神经元加权和z的标准差越大,∣z∣ >> 1出现的概率也越大,最终造成神经元饱和——学习缓慢

1/sqrt(nin)权重初始化

顺着上面的分析,一个比较自然的思路是:既然神经元加权和z的标准差与网络上一层神经元的数量nin有相关性,那么为了抵消掉神经元数量的影响,初始化分布的标准差就不应该是一个常数。

本篇引入的1/sqrt(nin)权重初始化就是答案所在:使用均值为0,标准差为1/sqrt(nin)的正态分布来初始化权重。sqrt表示开根号,同√。

继续使用之前的人为输入数据和网络架构:

  • 将z的表达式展开,初始共有1001项;
  • 人为令输入xj中的500个为0,所以z的表达式最终有501项;
  • 人为令输入xj的其余500个为1,所以z由500项wj和1项b组成,它们符合正态分布N(0,1/sqrt(nin));
  • 推导得到z符合均值为0,标准差为√(3/2)(3/2的平方根)正态分布,推导过程稍后解释;

得到了新的z的统计分布:

z的分布

此时的正态曲线变的非常尖锐,z的可能取值都在0附近,再看Sigmoid曲线就会发现,z在0附近时,σ(z)曲线最“陡”,σ'(z)值越大,学习速度越快。

注意一点,由于神经元的偏置b对于加权和z的贡献不受上一层神经元数量nin的影响,所以偏置b的初始化可以沿用之前的N(0,1)。

z的概率分布推导

回来解释下已知w和b的分布,如何计算z = ∑jwj + b的分布(xj为1,故省略)。

先准备两个特性:

  • 独立随机变量和的方差,是每个独立随机变量方差的和
  • 方差是标准差的平方

权重和偏置分布为N(0,1)情况下的推导:

  • 已知wj和b的标准差是1,那么wj和b的方差也是1;
  • 由于人为输入,z的展开式有501=1000/2+1项,每项标准差为1;
  • z的方差 = 12 x 501;
  • z的标准差 = √501;

权重分布为N(0,1/sqrt(nin)),偏置分布为N(0,1)情况下的推导:

  • 已知wj的标准差是1/sqrt(nin),那么wj的方差是1/nin,已知b的标准差是1,那么b的方差也是1;
  • 由于人为输入,z的展开式有nin/2+1项,前nin/2项为权重wj,每项方差为1/nin,最后1项为偏置b,方差为1;
  • z的方差 = 1/n x n/2 + 1 = 3/2;
  • z的标准差 = √(3/2);

结果对比

本篇基于12 TensorFlow构建3层NN玩转MNIST中的实现,单纯性的使用N(0,1/sqrt(nin))权重初始化与前者进行了对比,结果如下:

N(0,1)参数初始化
N(0,1/sqrt(n))

从输出明显看出,采用N(0,1/sqrt(nin))权重初始化的学习速度明显快了很多,第一次迭代Epoch 0就获得了94%的识别率,而前面的N(0,1)实现到Epoch 7才达到了94%。

不仅学习速率变快,30次迭代结束后,采用N(0,1/sqrt(nin))权重初始化的识别精度比前者高了1个百分点,达到了96%以上。

小结

本篇引入1/sqrt(nin)权重初始化方法,改变了神经元加权和z的随机概率分布,有效的避免了神经元饱和,最终不仅加快了学习速率,而且网络的性能也有明显的改善。

有很多其他的权重初始化方法,比如Xavier、He等,其基本思想都是相似的。

附完整代码

N(0,1/sqrt(nin))权重初始化的有效性分析,花了我们不少功夫,但是代码实现却异常简洁:

W_2 = tf.Variable(tf.random_normal([784, 30]) / tf.sqrt(784.0))
...
W_3 = tf.Variable(tf.random_normal([30, 10]) / tf.sqrt(30.0))

完整代码:

import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

FLAGS = None


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W_2 = tf.Variable(tf.random_normal([784, 30]) / tf.sqrt(784.0))
    b_2 = tf.Variable(tf.random_normal([30]))
    z_2 = tf.matmul(x, W_2) + b_2
    a_2 = tf.sigmoid(z_2)

    W_3 = tf.Variable(tf.random_normal([30, 10]) / tf.sqrt(30.0))
    b_3 = tf.Variable(tf.random_normal([10]))
    z_3 = tf.matmul(a_2, W_3) + b_3
    a_3 = tf.sigmoid(z_3)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    loss = tf.reduce_mean(tf.norm(y_ - a_3, axis=1)**2) / 2
    # loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
    train_step = tf.train.GradientDescentOptimizer(3.0).minimize(loss)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # Train
    best = 0
    for epoch in range(30):
        for _ in range(5000):
            batch_xs, batch_ys = mnist.train.next_batch(10)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
        accuracy_currut = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                        y_: mnist.test.labels})
        print("Epoch %s: %s / 10000" % (epoch, accuracy_currut))
        best = (best, accuracy_currut)[best <= accuracy_currut]

    # Test trained model
    print("best: %s / 10000" % best)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/MNIST/',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

下载 tf_15_mnist_nn_weight_init.py

上一篇 14 交叉熵损失函数——防止学习缓慢
下一篇 16 L2正则化对抗“过拟合”


共享协议:署名-非商业性使用-禁止演绎(CC BY-NC-ND 3.0 CN)
转载请注明:作者黑猿大叔(简书)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 196,583评论 5 462
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 82,669评论 2 374
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 143,684评论 0 325
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,682评论 1 267
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,533评论 5 358
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,396评论 1 275
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,814评论 3 387
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,458评论 0 255
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,745评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,789评论 2 314
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,565评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,410评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,828评论 3 300
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,050评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,342评论 1 253
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,793评论 2 342
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,010评论 2 337

推荐阅读更多精彩内容