作为读paper系列的第一篇,算是给新年立个flag吧。以前看东西不习惯整理,到现在遇到瓶颈了,很多东西你看懂了但是没办法说给别人听懂,那说明还是没有理解到位,希望借博客来提高自己的表达和快速梳理的能力吧。本文也是供自己备忘查阅,有错误和不到位的地方希望大家指出。
引言
神经网络虽然强大,但是其训练过程十分复杂,往往涉及到复杂的调参技巧,网络越深就越难调,业内人士常常调侃自己是“炼丹师”。优化困难的原因之一就是网络的输入(或隐藏层输出分布)在不断变化,使得网络得不断地学习“新”的分布,疲于奔命,如果能把变化的数据映射到某个固定的分布上,这将更利于网络的学习和收敛。图像任务中我们会把输入做一个减均值除以方差的预处理,实际上这就是将输入数据映射到了一个更易学习的分布上,而Batch Normalization则是将这种操作拓展到了网络的内部。为了更好地理解Batch Normalization的思想,我们简单补充一下Covariate Shift、Internal Covariate Shift和Whitening的知识,读者也可以参考[2]的介绍。本文分成三个部分,前两个部分作为铺垫,若想直入主题的朋友可直接跳到最后的部分(Batch Normalization)。
(一)Covariate Shift和Internal Covariate Shift
当学习系统的输入分布发生了变化,我们就称发生了Covariate Shift(协变量漂移,简称CS)。在模型的训练和测试两个阶段都可能发生CS,如果在训练过程当中发生了CS,那模型就需要不断地去学习新的分布,降低了学习效率,并且分布的复杂程度可能超过了模型容量,降低了学习效果。如果在测试中发生了CS,则破坏了机器学习中的独立同分布假设(IID),常规的机器学习是要假设训练数据和测试数据的分布是一致的,如果测试集分布和训练集的分布存在差异,则模型精度将会下降,这就像给学生考了一道超纲的题,没学过自然做不出来。
我们可以将CS的问题推广到网络的中间层,[1]将Internal Covariate Shift(内部协变量漂移,简称ICS)定义成:由于网络在训练过程中参数的更新而导致网络隐藏层输出分布的发生变化。比如我们可以将中的看作是的输入,在网络的训练过程中,会随着参数的变化而发生变化,导致F2难以学习。
(二)Whitening和Normalization
Whitening(白化)是一种数据预处理过程,其目的是降低输入数据的冗余性,因为白化后的数据具有白噪声的特征,所以我们称之为“whitening”。白化后的数据具有一致的分布,能够加快网络收敛。
然而因为严格的Whitening涉及到矩阵分析,计算复杂,我们通常会用Normalization(归一化)来替代白化,Normalization就是将分布变换为均值方差一致的分布,我们常用的减均值除以方差就是一种Normalization操作。虽然均值方差一致不等于分布一致,但是因为它也是对数据的分布做了简单约束,能在一定程度上缓解CS或ICS问题,并且它的计算速度快和可微性好,所以Bath Normalization使用的是Normalization而不是严格的Whitening.
另外,关于Normalization后为什么可以加快收敛,还可以参考李宏毅老师网课的BatchNormalization一节,[3]是我在网上找到的一个网友的笔记。
Towards Reducing Internal Covariate Shift
在[1]中Towards Reducing Internal Covariate Shift章节,作者说了他们为什么使用Normalization来缓解ICS的思路,我没怎么读懂,网上相关的解读也很少,但我感觉这一段很重要,科研的过程不仅是要看结果,了解结果产生的过程和思路才是更重要的,比如[4]就接使用Whitening来缓解ICS。我有时间研究清楚之后再来更新,我先把目前的思路写在下面。
既然前人的工作对输入数据做了Whitening并取得了加速收敛的效果,那么一个拍脑袋的缓解ICS的办法就是对每一层的输出也使用Whitening。BN作者刚开始研究的时做了一个简单尝试。考虑某一层的输出为,为该层输入,为可学习参数,遍历整个数据集得到。我们现在要将白化得到(简单起见,只考虑减均值的过程)。在反向传播的时候我们需要更新参数,得到,如果优化器忽略掉和的相关性的话(显然是有关的),会有,那么在下一次训练时,,在随后的输出中,归一化就失效了。随着训练的继续,会越来越大,如果加上尺度归一化的话这个问题会更加严重。
上述的问题就在于优化过程没有意识到归一化。我们在求梯度的时候应该把归一化过程给考虑进去。考虑一个,归一化的输出不仅依赖于当前的输入,还依赖整个数据集。如果要进行严格的白化操作,在反向传播的时候我们就需要求两个雅可比行列式:和,忽略第二个式子就会导致前面说的问题。严格的白化会涉及到复杂计算,效率低,所以本文使用简单的归一化,并且使用mini-batch代替整个数据集来加快计算。
(这一段我也是不知所云,如果谁读懂了希望能够给我指点一下)
Batch Normalization
本章会介绍BN的具体计算方法:归一化+去归一化,然后结合我的实验经验阐述训练和测试中的一些事项。
(一)计算方法
BN的计算可以分成两个步骤:归一化+去归一化。首先是归一化,为了减少ICS,同时使得算法利于集成到训练过程中,Batch Normalization(BN)做了两点简化:1)使用特征Normalization代替Whitening;2)只mini-batch代替整个训练集来计算统计量。对于一个维的特征,BN首先对其进行归一化:
为样本方差,为方差的无偏估计,
m为batch size,是为了防止除以0的一个微小常数。这样变换后得到的服从均值为0,方差为1的分布。
然后是去归一化。去归一化操作实际上就是一个缩放平移操作,如果把网络每一层的特征都归一化会限制网络的学习能力,比如sigmod函数本来可以作为门限函数表示开/关的逻辑信息,但是归一化后就将只能将特征限定在线性区域了,所以作者在归一化后又进行了一次线性变换:,其中和都是可学习的,当,时,整个BN层就是恒等变换。去归一化操作还可以这么去理解,归一化操作限制了网络的学习能力,去归一化操作引入了新的可学习参数,把丢失的学习能力又补了上来。
整个BN的算法流程如下:
读到这里,大家可能会有疑问,因为最后的scale and shift的操作,使得BN的输出并不是归一化的,这不是与减少ICS的初衷相悖吗?同样的问题也在[5]中被指出了。而我的理解是,虽然BN的输出不一定是归一化的,但是它引入的中间量是一个归一化分布,这个变量对于加速网络收敛非常重要。我们可以这样理解:如果把看作输入,将和以及后面的网络都看作是一个用于学习的分布的系统,比起学习更复杂的的分布,学习归一化分布自然更简单。
(二)实验启发
1.关于BN参数的初始化。其实我自己在实验过程中发现有很多工程实际上直接令BN层的beta=0和gamma=1,然后冻结这两个参数,相当于把BN退化成一个特征归一化层,实际效果也很好。另外还有一个经验就是如果要在一些预训练模型上加入BN层,比如VGG出来的时候还没有BN技术,如果我们直接在原有网络上加BN,可能会破坏预训练的状态,那我们可以将参数初始化成,
,这样BN的初态就是一个恒等变换了,可以让网络可以有一个比较好的初态。
2.关于测试阶段。我们会希望测试样本是独立的,也就是说测试样本间不能互相影响,比如我们不希望这次以batch size=1测试得到一个结果A,下次以batch size=2测试,却得到不一样的结果B。BN的计算过程中会涉及到两个统计量:均值(mean)和方差(var)。我们在测试时要使用固定的mean和var,测试阶段一般是使用训练集的mean和var。主流的自动微分框架会在训练时用滑动平均的方式统计mean和var,当你使用测试模式的时候,框架会自动载入统计好的mean与var。比如MXNet会使用下式来统计均值与方差:
通常momentum会设得很大,比如0.99,以便得到更准确的统计值,如果你发现模型在训练和测试时表现差异太大,可以检查下这个值。
3.关于先激活再BN,还是先BN再激活。我把前者简称后置BN,后者简称前置BN。[1]中采用的是前置BN,但实际实验中我发现两者差异并不大(有时后置BN效果反而更好一点)。可能有的人认为只有前置BN才会输出归一化的值,才能有效果,但如我在(一)计算方法一节中说的那样,这两种方法的最终输出都不是归一化的,关键在于两者都引入了归一化的中间量。我在文末提供了相关的拓展阅读,[6]为网友讨论,[7]提供了直观的实验结果。
4.关于BN为什么能够加速收敛。这个问题在[1]的Experiments部分有详细的讨论。(1)BN可以将特征约束在激活函数的线性区域,缓解梯度消失;(2)在网络的多层都加入BN可以缓解训练的发散问题,允许采用更大的学习率;(3)随机采样提供了一定的正则化约束,有防过拟合的效果。
总结
BN通过引入层内的批归一化操作对特征进行归一化,减少ICS,实现了加速网络收敛的效果。但关于BN为什么不直接使用白化操作我还没理解,可能进一步阅读[4]和Group Normalization后能找到答案。
参考资料
[1]Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
[2]Internal Covariate Shift与Normalization
[3]12. 批标准化(Batch Normalization
[4]Decorrelated Batch Normalization
[5]Intro to Optimization in Deep Learning: Busting the Myth About Batch Normalization
[6]Batch Normalization before or after ReLU?
[7]caffenet-benchmark/batchnorm.md