一个深度学习的框架,在训练过程中发现,前向的运算部分没问题,但是求导回传后,就出现整个网络失效的问题(整个网络权重全部更新为0),检查了很久以后,找到一个容易被忽视的问题。
就现在很多模型中所用到的loss公式中大多数都是基本初等函数,这些基本初等函数中,存在一些定义域为R,但是其导数定义域不是R的,或者是函数定义域与其导数定义域不相同的。例如,我在应用中用到的求平方根的函数,该函数定义域为非负实数,但是该函数导数的定义域为正实数,这样导致了,一旦出现变量为0,那么其导数就为不可求的状态,进而导致了回传更新参数出错。
修改解决办法很简单,比如在求平方根函数的调用中加入很小的数,保证其可导,例如:
tf.sqrt(x+1e-20)
但是这样的解决只是一个trick的解决办法,与此同时,你需要关注的是,你代码计算出来的x,是否应该本身的值域就是x>0,那么这里出现0值,一定是前面计算存在问题;第二就是如果x的值域本身就包含0,那么有多大的概率会导致计算取0。这些都是通过上述折中的办法解决问题时,应该非常谨慎注意的问题。
与此同时,需要探讨的是TensorFlow框架实现的问题,为什么在导数出现错误的情况下,还允许训练继续,更新错误的权重,而不是抛出异常,这是值得探讨的问题。