MPNN:消息传递神经网络

论文标题:Neural Message Passing for Quantum Chemistry
论文链接:https://arxiv.org/abs/1704.01212
论文来源:ICML 2017

一、概述

近年来,随着量子化学计算和分子动力学模拟等实验的展开产生了巨大的数据量,大多数经典的机器学习技术都无法有效利用目前的数据。而原子系统的对称性表明,能够应用于网络图中的神经网络也能够应用于分子模型。所以,找到一个更加强大的模型来解决目前的化学任务可以等价于找到一个适用于网络图的模型。

本文的目标是证明:能够应用于化学预测任务的模型可以直接从分子图中学习到分子的特征,并且不受到图同构的影响。本文提出的MPNN是一种用于图上监督学习的框架,能够概括之前一些文献提出的一些方法,并且能够按照这个框架提出一些新的架构。本文提出的新的MPNN变种基于实际的应用场景:预测有机小分子的量子力学性质。并且,作者希望以后提出的新的MPNN变种能够从实际的应用出发,从实际应用中获得启发。

本文以QM9作为benchmark数据集,这个数据集由130k个分子组成,每个分子有13个性质,这些性质是通过一种计算昂贵的量子力学模拟方法(DFT)近似生成的,相当于13个回归任务。这些任务似乎代表了许多重要的化学预测问题,并且目前对许多现有方法来说是困难的。

本文提出的模型的性能度量采用两种形式:
①DFT近似的平均估计误差;
②化学界已经确立的目标误差,称为“化学精度”。

二、消息传递神经网络

本文介绍了能够应用MPNN框架的8篇文献,为了简便起见,以处理无向图G为例,无向图G包含节点特征x_v和边的特征e_{vw},将这种形式推广到有向重图是不重要的。MPNN前向传播的过程包含两个阶段,即消息传递阶段(message passing phase)和读出阶段(readout phase)。消息传递阶段运行T个时间步并且依赖消息函数M_t以及节点更新函数U_t。在消息传递阶段,每个节点的隐状态h_{v}^{t}都会根据消息m_{v}^{t+1}进行更新,具体过程是:

m_{v}^{t+1}=\sum _{w\in N(v)}M_{t}(h_{v}^{t},h_{w}^{t},e_{vw})\\ h_{v}^{t+1}=U_{t}(h_{v}^{t},m_{v}^{t+1})

N(v)代表节点v的邻居节点集合。读出阶段使用某种读出函数R来为整个图计算一个特征向量:

\hat{y}=R(\left \{h_{v}^{T}|v\in G\right \})

M_t,U_t,R都是用来学习的可微函数。R作用于节点状态集合,并且必须对节点状态的排列保持不变,以使MPNN对图同构保持不变。注意MPNN也可以学习边的特征,这可以通过为每条边引入隐状态h_{e_{vw}}^t并应用前面的两个过程来实现。接下来,我们通过指定所使用的消息函数M_t、顶点更新函数U_t和读出函数R来定义以前文献中的模型。

  1. Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)

本文提出的模型采用的消息函数是:

M(h_{v},h_{w},e_{vw})=(h_{w},e_{vw})

(\cdot ,\cdot )代表拼接。节点更新函数是:

U_{t}(h_{v}^{t},m_{v}^{t+1})=\sigma (H_{t}^{deg(v)}m_{v}^{t+1})

deg(v)是节点v的度,H_{t}^{N}对应于时间步t以及节点度N的学习矩阵。读出函数将之前所有隐状态h_{v}^{t}进行连接:

f\left (\sum _{v,t}softmax(W_{t}h_{v}^{t})\right )

f是一个神经网络,W_{t}是时间步t的一个学习矩阵。

这样的消息传递的方法可能有问题,因为最终得到的消息向量为m_{v}^{t+1}=(\sum h_{w}^{t},\sum e_{vw}),这是边和节点状态向量的加和,缺乏边和节点状态向量的交互。

  1. Gated Graph Neural Networks (GG-NN), Li et al.(2016)

消息函数为:

M_{t}(h_{v}^{t},h_{w}^{t},e_{vw})=A_{e_{vw}}h_{w}^{t}

A_{e_{vw}}是特定于边的标签的学习矩阵(这个模型假设边有离散的标签)。更新函数如下:

U_{t}=GRU(h_{v}^{t},m_{v}^{t+1})

GRU就是门控循环单元,一种循环神经网络,对于每个时间步进行权重共享,也就是说每个时间步共用同一个更新函数。最后,读出函数:

R=\sum _{v\in V}\sigma \left (i(h_{v}^{(T)},h_{v}^{0})\right )\odot \left (j(h_{v}^{(T)})\right )

i,j代表神经网络,\odot代表哈达玛积。

  1. Interaction Networks, Battaglia et al. (2016)

这个模型考虑了两种情况,一种是每个节点都有自己的目标,另一种是有一个graph level的目标。它还考虑了在每个时间步骤中存在node level影响的情况,在这种情况下,更新函数将(h_v, x_v, m_v)连接作为输入,其中x_v是一个外部向量,表示顶点v受到的外部影响。消息函数M(h_{v},h_{w},e_{vw})是一个神经网络,使用拼接向量(h_v, h_w, m_v)作为输入,节点更新函数M(h_v, x_v, m_v)也是一个神经网络,使用(h_v, x_v, m_v)作为输入。最终读出函数得到一个graph level的输出:R=f(\sum _{v\in G}h_{v}^{T}),这里f是一个神经网络。注意,这个模型只定义了T=1的情况。

  1. Molecular Graph Convolutions, Kearnes et al. (2016)

这个模型与之前的MPNNs稍微有一些不同,是因为它引入了边的表示e_{vw}^{t},并且会在消息传递阶段进行更新。消息函数为:

M(h_{v}^{t},h_{w}^{t},e_{vw}^{t})=e_{vw}^{t}

节点更新函数为:

U_{t}(h_{v}^{t},m_{v}^{t+1})=\alpha \left (W_{1}(\alpha (W_{0}h_{v}^{t}),m_{v}^{t+1})\right )

同样的(\cdot ,\cdot )代表拼接,\alpha代表ReLU激活函数,W_1,W_0是学习权重矩阵。边状态更新的方式是:

e_{vw}^{t+1}=U_{t}^{'}(e_{vw}^{t},h_{v}^{t},h_{w}^{t})=\alpha \left (W_{4}(\alpha (W_{2},e_{vw}^{t}),\alpha (W_{3}(h_{v}^{t},h_{w}^{t})))\right )

W_i都是学习矩阵。

  1. Deep Tensor Neural Networks, Sch¨utt et al. (2017)

消息函数为:

M_{t}=tanh\left (W^{fc}((W^{cf}h_{w}^{t}+b_{1})\odot (W^{df}e_{vw}+b_{2}))\right )

W^{fc},W^{cf},W^{df}是矩阵,b_{1},b_{2}是偏置向量。更新函数为:

U_{t}(h_{v}^{t},m_{v}^{t+1})=h_{v}^{t}+m_{v}^{t+1}

读出函数使用单个隐层神经网络独立地通过每个节点,并对输出进行求和:

R=\sum _{v}NN(h_{v}^{T})

  1. Laplacian Based Methods, Bruna et al. (2013); Deffer�rard et al. (2016); Kipf & Welling (2016)

8篇文献中有3篇属于这一类。其中两篇采用消息函数:

M_{t}(h_{v}^{t},h_{w}^{t})=C_{vw}^{t}h_{w}^{t}

矩阵C_{vw}^{t}通过拉普拉斯矩阵的特征向量和模型的学习参数来参数化。更新函数为:

U_{t}(h_{v}^{t},m_{v}^{t+1})=\sigma (m_{v}^{t+1})

\sigma代表非线性函数,比如ReLU激活函数。

另一篇文献采用消息函数:

M_{t}(h_{v}^{t},h_{w}^{t})=c_{vw}h_{w}^{t}

这里c_{vw}=(deg(v)deg(w))^{-1/2}A_{vw}。节点更新函数为:

U_{t}(h_{v}^{t},m_{v}^{t+1})=ReLU(W^{t}m_{v}^{t+1})

三、MPNN变种

本文以前述GG-NN作为baseline进行改进,提出一种新的MPNN变种。下文中以d代表节点特征的维度,以n代表图的节点的数量。这一变种适用于有向图,这意味着入边和出边有分别的信息通道,那么信息m_{v}m_{v}^{in}m_{v}^{out}拼接而成,当我们将模型应用无向图时,就把无向图的边看做两条边,包含一条入边,一条出边,有相同的标签,这样处理的方式意味着信息通道的大小是2d而不是d

模型的输入是每个节点的特征向量x_v以及邻接矩阵A,邻接矩阵A具有向量分量,表示分子中的不同化学键以及两个原子之间的成对空间距离。初始状态h_v^0是原子输入特征集合x_v,并且需要padding到维度d。在实验中的每个时间步t都要进行权重共享,并且更新函数采用GRU。

  1. 消息函数

GG-NN原本采用的消息函数,采用矩阵相乘的方式(注意原来的GG-NN的边有离散的标签,而现在我们假设的边有一个特征向量e_{vw}):

M(h_{v},h_{w},e_{vw})=A_{e_{vw}}h_{w}

A_{e_{vw}}是特定于边的标签的学习矩阵。为了兼容边特征,本文提出了新的消息函数:

M(h_{v},h_{w},e_{vw})=A(e_{vw})h_{w}

A(e_{vw})是一个神经网络,将边的特征向量e_{vw}映射到一个d\times d的矩阵。上述两种消息函数的特点是消息只依赖于h_we_{vw}而不依赖于h_v,如果消息同时依赖目标节点与源节点,那么应该是更加高效的,可以尝试以下这种消息函数:

m_{wv}=f(h_{w}^{t},h_{v}^{t},e_{vw})

这里f是一个神经网络。

对于有向图,一共有两个消息函数M^{in}M^{out},对于边e_{vw}应用哪个消息函数取决于边的方向。

  1. 虚拟图形元素

本文探索了两种方式来改变模型中信息的传递。第一种是为未连接的节点对添加一个单独的“虚拟”边类型。这一处理可以在预处理时实现,效果是可以使得在传播过程中让信息传播更远的距离。

另一种方式是添加一个“master”节点,让它通过一种特殊类型的边与所有节点连接。“master”节点充当全局暂存空间,每个节点在消息传递的每个步骤中都对其进行读写操作。另外“master”节点拥有单独的节点维度d_{master},以及内部更新函数(实验中是GRU)的单独权重。这同样可以使得在传播过程中让信息传播更远的距离。这样可以允许模型有更大的容量,同时也不会过多的损失效率,其复杂度为O(|E|d^{2}+nd_{master}^{2})

  1. 读出函数

读出函数采用set2set模型,这个模型使用T=\left \{(h_{v}^{T},x_{v})\right \}作为输入,然后再经过M步计算后生成一个graph level的embeddingq_{t}^{*},其中过程与T内节点顺序无关,最终将q_{t}^{*}输入到一个神经网络中来获得最终输出。具体参考文献: Sequence to sequence for sets。

  1. Multiple Towers

由于消息传递阶段的复杂度为O(n^{2}d^{2}),当nd增大时,计算上就会是昂贵的。处理的方法是将h_{v}^{t}拆分成k个不同的d/k维的embeddingh_{v}^{t,k},并且在每个h_{v}^{t,k}上独立运行传播过程得到\left \{\tilde{h}_{v}^{t+1,k},v\in G\right \},然后进行混合:

(h_{v}^{t,1},h_{v}^{t,2},\cdots ,h_{v}^{t,k})=g(\tilde{h}_{v}^{t,1},\tilde{h}_{v}^{t,2},\cdots ,\tilde{h}_{v}^{t,k})

g代表神经网络,(x,y,\cdots )代表拼接,g在所有节点上共享。这样的混合过程保持了节点排列的不变性,同时允许图的不同副本在传播阶段相互通信。这样的设计提升了计算效率,比如在使用矩阵相乘的消息函数时一个副本的复杂度为O(n^{2}(d/k)^{2}),当有k个副本时一共为O(n^{2}d^{2}/k)

四、输入表示

一个分子有很多特征,如下图所示:

分子特征

边的特征包括化学键与距离,因此有以下三种表示方式:
①化学图(Chemical Graph):在不考虑距离的情况下,邻接矩阵的值是离散的键类型:单键,双键,三键或芳香键;
②距离分桶(Distance bins):GG-NN基于矩阵乘法的消息函数的前提假设是“边信息是离散的”,因此作者将键的距离分为 10 个 bin,比如说 [2,6]中均匀划分 8 个 bin,[0,2]为 1 个 bin,[6, +∞]为 1 个 bin;
③原始距离特征(Raw distance feature):也可以同时考虑距离和化学键的特征,这时每条边都有自己的特征向量,此时邻接矩阵的每个实例都是一个 5 维向量,第一维是距离,其余4维是一个独热向量,代表4种不同的化学键。

五、实验

实验中对比了本文提出的方法与现有的方法:

对比

以下为不考虑空间信息的结果:

不考虑空间信息

以下为一些消融实验:

消融实验

具体实验设置参照原文。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,651评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,468评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,931评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,218评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,234评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,198评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,084评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,926评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,341评论 1 311
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,563评论 2 333
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,731评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,430评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,036评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,676评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,829评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,743评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,629评论 2 354

推荐阅读更多精彩内容

  • 图计算任务成功的关键是为节点生成节点表征,本节主要任务如下: 1.首先学习图神经网络生成节点表征的规范-消息传递范...
    从你美丽的流域阅读 1,060评论 0 0
  • 笔记部分 消息传递范式用表示层中节点的节点特征, 表示从节点到节点的边的特征,消息传递图神经网络可以描述为公式i...
    79f3c66c2fe7阅读 359评论 0 0
  • 参考链接:https://github.com/datawhalechina/team-learning-nlp/...
    沫2021阅读 413评论 0 0
  • 图嵌入综述 图分析任务可以大致抽象的分为以下四类:(a)节点分类(b)链接预测(c)聚类(d)可视化。 真实的图(...
    逃淘桃阅读 5,898评论 0 1
  • 表情是什么,我认为表情就是表现出来的情绪。表情可以传达很多信息。高兴了当然就笑了,难过就哭了。两者是相互影响密不可...
    Persistenc_6aea阅读 124,957评论 2 7