原因一:除0错误
- 数据原因:由于路径或脏数据等原因,造成数据读取出差
解决方法:判断出现nan的数据的id,剔除,简单粗暴 - 代码原因:数据本身就存在0值,代码在执行过程中将其置为分母
注意:很多代码都存在隐式除0操作,因为现成的损失函数有不少采用了log函数,如CrossEntropyLoss和log_softmax。使用log函数的优点可以拖到最底部,后文会描述。
y = log(softmax(x))
y' = 1/softmax(x) ##出现除0操作,softmax的值域区间为[0,1]
解决方法: 令y= log(softmax(x)+EPS), 其中EPS可以取1e-12(极小值)
原因二:学习策略、超参设置不当
- 学习速率过大,尝试调小
- batch_size过大,尝试调小
- 尝试使用batch normalization和instance normalization
为什么很多损失函数都以log函数为原型?
令X表示我们模型的输出,而我们希望的预测类别是c,
则有预测概率P=softmax(X)
其中Pc:输出X属于c类的概率
一般来说,我们希望输出概率Pc越大越好,如何优化?
最简单的做法就是令损失函数
loss_f = -Pc
这样的做法在理解上十分直观,当我们优化loss_f取得最小值时,即Pc取得最大值
但我们还是需要使用log函数!!!
为方便讲解,我们暂时忽略损失函数前面的符号,
令loss_f = Pc
loss_g = log(Pc)
- 等价性
loss_f与loss_g均为单调递增函数,优化loss_f与loss_g等价 - 惩罚力度
loss_f是斜率为1的直线函数,具有梯度不变形,意味着不管Pc的值为0.9还是0.1,惩罚力度都是一样的
而loss_g是曲线函数,接近0的地方值接近”无穷大“,接近1的地方值接近0,意味着当Pc=0.9时惩罚力度更小,而Pc接近0的时候表示偏差太大,惩罚力度非常大。