本节是自动求导框架技术的第二节,本系列其余文章包括
1. 单变量的链式法则
链式法则是神经网络反向传播算法的基础,参考多元复合函数求导链式法则,对于下面一组函数:
变量 w 对于 x 和 y 这两个变量的导数求解过程为:
2. 基于张量的链式法则
如果把上面的函数中的变量 x,y,z,u,v,w 换成矩阵或者张量,上述链式法则依然成立。求导的方式参考矩阵求导。为了方便理解,矩阵求导过程可以省略合并那一步。链式法则中的乘法改为矩阵乘法。读者可以自己手动验证一下该过程的正确性。
3. 计算图与前向传播
实际上,上面一组函数可以用更直观的计算图的方式表达出来,计算图是一种有向无环图(DAG),上述四个函数用计算图表达的结果如下:
使用计算图之后,每个节点可以看作一个产生运算结果的操作节点 operator_node,在每个操作节点中包含了该操作节点的输出 output 矩阵和 这个节点的的操作方法 op (), 其中 op () 方法会调用当前节点所依赖的节点的数据来完成当前节点 output 的计算。那么前向传播过程,也就是已知 x,y 求解 w 的过程可以通过以下方式完成:
1. 初始化 x,y 这两个 operator_node 的 output;
2. 对计算图进行拓扑排序,保证每个节点调用 op () 函数的时候其依赖的节点都已经计算出输出。对于排序后的节点,依次调用其中的 op () 函数,最后就可以得到 w 节点的输出。
4. 计算图与反向传播
反向传播就是计算每个计算节点的梯度,用于更新参数的值。以上面计算图为例子,反向传播的过程中每个计算节点增加一个叫 sum_grad 的矩阵用于存储操作节点 w 对于其余每个操作节点的导数矩阵,同时添加一个 grad_op () 函数用于计算每个节点对其依赖节点的导数矩阵 grad;假设 w 这个终节点对当前节点(比如 u 节点)的导数矩阵是 now_sum_grad,终节点 w 对当前节点所依赖的节点(比如 z 节点)的导数矩阵为 dep_sum_grad,那么 dep_sum_grad 的计算过程可以写为:
这就是链式法则在计算图中的实现。反向传播的过程为:
1. 对计算图转置(就是把图中的边反向),然后进行拓扑排序。对排序后的每个节点依次调用 grad_op () 函数,进而使用上面的链式法则计算出每个节点的 sum_grad 矩阵;
2. 用 sum_grad 矩阵去更新需要更新参数的节点。