论文笔记 | Neural Random Access Machine

论文信息

项目 内容
作者 Karol Kurach & Marcin Andrychowicz & Ilya Sutskever
发表 ICLR 2016

摘要和前言

本文实现了一个可以操作和读取指针的神经网络架构,称为 Neural Random Access Machine 。其特点是可以操作一个可变大小的外部记忆。通过学习需要操作指针才能完成的任务验证其能力,并且发现模型可以解决此类问题并使用链表、二叉树等结构。对于简单的任务,模型可以泛化到任意长度的序列上。在特定的假设下,记忆可以在常数时间内读取。

作者认为,神经网络的进步来源于:结构更深的同时,参数更少,且可训练。 Neural Turing Machine 和 Grid-LSTM 的成功在于深度、短期记忆的大小和参数数量,三者相互独立。

模型

模型描述

模型有 R 个寄存器,每个寄存器储存一个整数,用 {0, 1, \dots M-1} 上的分布来表示。控制器不能直接访问寄存器,但可以通过一系列预定义的“模块”(module,或称“门”,gate)来与之交互,举例来说,整数加法,等值测试等等。

因此模块记作 m_1, m_2, \dots, m_Q ,且

m_i\ :\ \{0, 1, \dots M-1\} \times \{0, 1, \dots M-1\} \rightarrow \{0, 1, \dots M-1\}

也就是集合上的一个二元运算。

模型每一时间步上进行:

  1. 控制器根据寄存器的值取得一些输入
  2. 控制器更新内部状态(是一个LSTM)
  3. 控制器输出一个“模糊电路”(fuzzy circuit)的描述。包含输入 r_1, \dots, r_R ,门 m_1, \dots, m_QR 个输出
  4. 寄存器的值被模糊电路的输出覆写

其中电路构成如下:

模块 m_i 的输入是控制器从 \{r_1, \dots, r_R, o_1, \dots, o_{i-1}\} 中选出的。其中:

  • r_j 表示当前时间步第 j 个寄存器储存的值
  • o_j 表示当前时间步第 j 个模块的输出

控制器对输入进行加权平均,决定哪些值作为输入。因此,对于 1 \le i \le Q

o_i = m_i\left(\left(r_1, \dots, r_R, o_1, \dots, o_{i-1}\right)^T\textbf{softmax}(a_i), \left(r_1, \dots, r_R, o_1, \dots, o_{i-1}\right)^T\textbf{softmax}(b_i)\right)

其中 a_ib_i 是控制器生成的权重向量。

为使模块接收概率分布输入,并输出一个分布,修改定义如下:

\forall_{0 \le c \lt M}\ \mathbb{P}(m_i(A B)=c) = \displaystyle\sum_{0 \le a, b \lt M}\mathbb{P}(A=a)\mathbb{P}(B=b)[m_i(a, b) = c]

计算完成后,控制器决定哪些结果应该重新存储到寄存器中:

r_i := (r_1, \dots, r_R, o_1, \dots, o_Q)^T\textbf{softmax}(c_i)

其中 c_i 是控制结果储存的权重向量。

每一时间步的开始,控制器接收一些由寄存器决定的输入。朴素的想法可能是将寄存器的值直接作为输入。这样的问题是,如果将整个分布作为输入,模型的参数数量将与 M (即寄存器的取值上限)有关。下一节将把 M 联系到一个外部 RAM 上,因此会妨碍模型泛化到不同的存储大小上。

因此对于每个寄存器,我们只输出一个标量,\mathbb{P}(r_i=0) 。这种设计也有一个优势,即限制控制器得到的输入信息量,强制它使用模块解决问题,而非自己解决。特别地,如果 r_i \in \{0, 1\} ,该标量保留了全部的信息。如果 r_i 是一个布尔模块的输出,那么它就属于这种情况。例如,不等值测试模块 m_i(a, b)=[a \lt b]

记忆磁带

如果将寄存器初始化为一个输入的序列,在一定时间步后,模型将输出序列产生到寄存器里,那么可以描述一个 seq-to-seq 模型。这种使用方式的缺点在于,无法泛化到长序列上,因为可处理的序列的长度等于寄存器数,而它是一个常数。

因此,设计一个长度为 M 的记忆磁带,每个位置上是一个记忆单元。每个记忆单元储存一个 \{0, 1, \dots M-1\} 的分布。这一内容又可解释为一个磁带上的模糊指针。记忆的准确状态可以用矩阵 \mathcal{M} \in \mathbb{R}_M^M 来描述。\mathcal{M}_{i,j} 表示第 i 个记忆单元存储值 j 的概率。

模块仅使用两种模块和记忆磁带交互:

  1. READ ,接收一个参数作为输入(忽略第二个输入参数),输出记忆磁带该地址上的值。通过与上面类似的方法扩展定义到分布上。具体来讲,对于输入的模糊指针 p ,模块输出 \mathcal{M}^Tp
  2. WRITE ,接收输入指针 p 和值 a ,将指针 p 处的值替换为 a 。数学表示是 \mathcal{M} := (J-p) J^T \cdot \mathcal{M} + pa^T 。其中 JM1 组成的列向量, \cdot 表示按元素相乘。

记忆磁带同时也是一个输入/输出通道。记忆初始化成一个输入序列,希望模型将输出写到记忆中。

此外,每个时间步,控制器输出一个结束的概率 f_t = \textbf{sigmoid}(x_t) \in [0, 1] 。运行在时间步 t 前没有结束的概率是 \prod_{i=1}^{t-1}(1-f_i) ,恰好在时间步 t 输出结果的概率是 p_t = f_t\cdot\prod_{i=1}^{t-1}(1-f_i) 。还有一个超参数,最长时间步数 T 。如果该步没有结束,模型需要强制输出,即 p_T = 1 - \sum_{i=1}^{T-1}p_i

\mathcal{M}^{(t)} 表示第 t 个时间步的记忆矩阵。对于输入输出对 (x, y) ,其中 x, y \in \{0, 1, \dots M-1\}^M,当记忆被初始化为 x 时,定义损失函数为 -\sum_{t=1}^{T}\left(p_t\cdot \sum_{i=1}^{M}\log\left(\mathcal{M}^{(t)}_{i,y_i}\right)\right) 。或者使用对数似然函数定义损失函数,即 -\sum_{t=1}^{T}\log \left(\sum_{i=1}^{M}p_t\cdot\mathcal{M}^{(t)}_{i,y_i}\right)

此外,对于我们考虑的问题而言,输出序列通常比记忆短。我们可以在记忆单元上计算损失函数,因为输出已经被包含在内了。

离散化

在分布上进行计算复杂度很高,比如计算 READ 的时间复杂度是 \Theta(M^2) 。人们可能会想(我们在后面用实验证明了)中间值的分布具有很低的熵。在训练过后,我们使用一个离散化的模型进行推理。也就是只选取最有可能的输入,以及输出。具体来讲,就是把上面的 \textbf{softmax} 换成在最大值上输出 1 ,其他位置输出 0 的向量的函数。

离散化的模型每个寄存器和记忆单元中都储存一个 \{0, 1, \dots M-1\} 的整数。因此可以加速。

如果只替换 softmax 的话,寄存器和记忆单元仍可以是分布。根据上下文,此处离散化还包括将所有分布经过一个相同的离散化函数。

对于一个前馈控制器,以及较少数量的寄存器(比如小于20),推理可以进一步加速。因为控制器的输入仅为一些二进制的值,我们可以提前把每种配置都计算出来。

同上,控制器的输入仍可能是 0 到 1 的概率。

实验

训练中使用的技术有 Curriculum Learning [1] 、梯度截断、梯度随机噪声、更新权重后调整分布以使其仍然表示整数的概率分布、对输出的熵过低进行逐步递减的惩罚、限制 \log 计算以防止溢出。

这里介绍一下 Curriculum Learning 。

Continuation Method

为了求解非凸优化问题,我们可以使用 Continuation Method (CM)。基本思想是先计算一个平滑版本的问题,再逐渐降低平滑性。这里利用的直觉是,平滑版本的问题展现了全局特点。这种方法中,需要定义一系列的单参数的损失函数, C_\lambda(\theta)C_0 是一个容易优化的高度平滑的版本, C_1 是我们希望优化的版本。

从抽象的层次来看, CM 也是一系列训练标准。序列中的每一个训练标准都为样本设定了不同的权重,或者更一般地,重新为训练分布设置权重。最初,权重倾向于“简单的”样本,或者那些展示了简单概念的样本。序列中的下一个标准,将越来越提高较难样本的采样概率。序列的末尾,我们在训练样本上均匀采样,因此训练数据的分布就是原始的训练分布。

形式化表示如下:

z 是表示示例的随机变量(有监督学习中可能是 (x,y) 对),P(z) 是学习者最终应该学习到的训练样本分布。0 \le W_\lambda(z) \le 1 是在 \lambda 步分给 z 样本的权重,且 W_1(z) = 1 。对应的训练分布即

Q_\lambda(z) \propto W_\lambda(z)P(z) \ \forall_z

且使得 \int Q(z)dz = 1 ,因此

Q_1(z)=P(z) \ \forall_z

考虑从 \lambda = 0\lambda = 1 的单调递增序列。

定义:如果 Q_\lambda 的熵递增,则称其为一个 Curriculum 。即

H(Q_\lambda) < H(Q_{\lambda+\epsilon}) \ \forall_{\epsilon > 0}

并且

W_{\lambda+\epsilon}(z) \ge W_\lambda(z) \ \forall_z, \forall_{\epsilon > 0}

考虑 Q_\lambda 是有限集上的样例,这一过程对应于增加新的样本。某些实验中,仅仅将训练集划分为简单和完整两步就可以得到提升。另一个极端是随机采样。此时困难样本的概率逐渐增加,直到最后所有样本概率相等,均为 1 。

具体到本篇论文中,以序列的长度或者树的大小作为训练复杂度。

每次训练时,样本从一个由难度 D 决定的分布中采样得到。每当错误率降低到一定阈值以下,就提高难度,直到最大值。

具体的采样方法是:

首先从一个由 D 决定的分布中采样得到 d

  • 10%: 从所有可能难度中均匀采样
  • 25%: 从 [1, D+e] 中均匀采样,其中 e 服从每次实验成功概率为 \frac{1}{2} 的几何分布
  • 65%: d = D + e

再使用难度为 d 的样本作为训练样本的训练复杂度。

任务

选取的任务如下:

  1. Access: Given a value k and an array A, return A[k].
  2. Increment: Given an array, increment all its elements by 1.
  3. Copy: Given an array and a pointer to the destination, copy all elements from the array to the given location.
  4. Reverse: Given an array and a pointer to the destination, copy all elements from the array in reversed order.
  5. Swap: Given two pointers p, q and an array A, swap elements A[p] and A[q].
  6. Permutation: Given two arrays of n elements: P (contains a permutation of numbers (1, \dots, n) and A (contains random elements), permutate A according to P.
  7. ListK: Given a pointer to the head of a linked list and a number k, find the value of the k-th element on the list.
  8. ListSearch: Given a pointer to the head of a linked list and a value v to find return a pointer to the first node on the list with the value v.
  9. Merge: Given pointers to 2 sorted arrays A and B, merge them.
  10. WalkBST: Given a pointer to the root of a Binary Search Tree, and a path to be traversed (sequence of left/right steps), return the element at the end of the path.

模块

所有的模块都需要事先指定类型和顺序,本次实验中使用的如下:

  • READ
  • ZERO(a, b) = 0
  • ONE(a, b) = 1
  • TWO(a, b) = 2
  • INC(a, b) = (a+1) \mod M
  • ADD(a, b) = (a+b) \mod M
  • SUB(a, b) = (a−b) \mod M
  • DEC(a, b) = (a−1) \mod M
  • LESS-THAN(a, b) = [a < b]
  • LESS-OR-EQUAL-THAN(a, b) = [a \le b]
  • EQUALITY-TEST(a, b) = [a = b]
  • MIN(a, b) = \min(a, b)$
  • MAX(a, b) = \max(a, b)$
  • WRITE

实验结果

简单任务

前五个任务被划分为简单任务,因为在训练和测试中均达到了 0 错误率。而且训练结果泛化到序列长度为 50 也是 0 错误率。更进一步地,CopyIncrement 被验证可以泛化到任意长度。而对模型进行离散化也不会影响其表现。

让我们分析一下 Copy 的记忆、寄存器以及产生的电路图。

其中电路图是第二步之后的每一步。可以看到此时 r2 储存了转移的长度,每次更新到 r2 自己中,因此保持不变。r3 是累加器,每次进行加一后与 r2 中的较小值存到 r4 中。r4 代表当前读的地址,与 r2 相加后得到写的地址,因此二者通过一次读写完成复制。

因此,每两步(一步 r3 自增存储到 r4 直到与 r2 相等,另一步 r4 实际进行读写)完成一个元素的复制。

可以看到上面的电路持续产生地址常数 0 ,作为写的目的地址。

可以看到 r5 作为读写地址,每次由 r1 递增 1 更新到自己,并实现更新。

可以看到 r3 作为读地址递增,每次用目的地的 2 倍减 r3 减 1 作为写地址(注:实际上只对特定目的地址情况成立)。

困难任务

为了解决困难任务,引入了上面说的很多技术。最终在训练数据中把除了 WalkBSTMerge 的错误率调到了 0 。而另两个则调到了 1% 以下。

泛化较好的任务是 PermutationListKWalkBST 。离散化则只有 Permutation 没有损失性能。其余的错误率高达 70% 以上。

与已有模型比较

NTM 缺乏将一个指针储存在记忆中的自然的方式。因此作者估计其能完成 CopyReverse 这样的任务,而难以完成 ListKListSearchWalkBST 这样的涉及到指针的任务。

NRAM 的一个特点是缺乏基于内容的寻址,这是有意为之的,目的是加速内存访问。

结论

NRAM 可以解决一些算法类问题。部分解决方法可以泛化到任意序列长度。

参考链接


  1. Bengio
    Yoshua, et al. "Curriculum learning." Proceedings of the 26th annual international conference on machine learning. ACM, 2009.

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,142评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,298评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,068评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,081评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,099评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,071评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,990评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,832评论 0 273
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,274评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,488评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,649评论 1 347
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,378评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,979评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,625评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,796评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,643评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,545评论 2 352

推荐阅读更多精彩内容