<span style="font-size:16px"><div class="image-package"><img src="https://upload-images.jianshu.io/upload_images/26011021-4ae807cb0dd8cf89.jpeg" class="uploaded-img" width="auto" height="auto"/>
</div><p>This article first gives a brief overview of the GNN link-prediction task, then uses the Deep Graph Library (DGL) to implement a GNN for link prediction, with a detailed walk-through of the code. For the complete model code, follow the official account 【AI机器学习与知识图谱】 and reply: <strong>DGL第二讲完整代码</strong>
</p><p><strong>
</strong></p></span><strong><font size="4">1. Overview of GNN Link Prediction</font></strong><div><p><span style="font-size:16px">The GNN link-prediction task is to predict whether an edge exists between two nodes of a graph. Link prediction is needed in applications such as Social Recommendation and Knowledge Graph Completion. In implementation, link prediction is treated as a binary classification task:</span></p>
<p><span style="font-size:16px">1. Take the edges that exist in the graph as positive samples; 2. Negatively sample node pairs with no edge between them as negative samples; 3. Merge the positive and negative samples and split them into a training set and a test set; 4. Evaluate the model with standard binary-classification metrics, e.g. the AUC value.</span></p>
<p><span style="font-size:16px">In some scenarios, such as large-scale recommender systems or information retrieval, the model must evaluate the accuracy of its top-k predictions, so link prediction also needs ranking metrics: 1. MR (Mean Rank) 2. MRR (Mean Reciprocal Rank) 3. Hit@n.</span></p>
<p><span style="font-size:16px">Meaning of MR, MRR and Hit@n: suppose the knowledge graph contains n entities. Before evaluating: (1) take a correct triple (h, r, t) and replace its head entity h or tail entity t in turn with every other entity in the graph, producing n triples; (2) compute an energy value for each of these n triples, e.g. the value of ||h + r - t|| in TransE, so that each of the n triples has its own energy value; (3) sort the n triples by energy value in ascending order and record the resulting rank of each triple; (4) repeat the three steps above for every correct triple.</span></p>
<p style="text-indent:0pt"><strong><span style="font-size:16px">MR:</span></strong><span style="font-size:16px"> the average of the ranks of all correct triples in the graph;</span></p><p style="text-indent:0pt"><strong><span style="font-size:16px">MRR:</span></strong><span style="font-size:16px"> the average of the reciprocals of the ranks of all correct triples;</span></p><p style="text-indent:0pt"><strong><span style="font-size:16px">Hit@n:</span></strong><span style="font-size:16px"> the proportion of correct triples whose rank is at most n.</span></p>
<p style="text-indent:0pt"><span style="font-size:16px">Therefore, for link prediction, a smaller MR indicates a better model, while larger MRR and Hit@n indicate a better model. In the rest of this article we predict, on the Cora citation dataset, whether a citation relation exists between two papers.</span></p>
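<p style="text-indent:0pt"><span style="font-size:16px">To make the three ranking metrics concrete, here is a minimal sketch (the ranks below are made-up example data, not from the article) that computes MR, MRR and Hit@n from the ranks of the correct triples:</span></p>

```python
# Sketch: MR, MRR and Hit@n computed from the 1-based ranks of correct triples.
# `ranks` is hypothetical example data.

def mean_rank(ranks):
    # MR: average rank of the correct triples (lower is better)
    return sum(ranks) / len(ranks)

def mean_reciprocal_rank(ranks):
    # MRR: average reciprocal rank (higher is better)
    return sum(1.0 / r for r in ranks) / len(ranks)

def hit_at_n(ranks, n):
    # Hit@n: fraction of correct triples ranked within the top n (higher is better)
    return sum(1 for r in ranks if r <= n) / len(ranks)

ranks = [1, 3, 2, 10, 50]
print(mean_rank(ranks))             # -> 13.2
print(mean_reciprocal_rank(ranks))  # roughly 0.39
print(hit_at_n(ranks, 3))           # -> 0.6
```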
<span style="font-size:16px"><span style="font-size:20px"><strong>2. Implementing GNN Link Prediction</strong></span></span><p style="text-indent:0pt"><span style="font-size:16px">Next we use the DGL framework to implement a GNN for the link-prediction task, with a detailed explanation of the code. First, as shown below, load the required dgl and pytorch libraries:</span></p>
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
<p style="text-indent:0pt"><strong>Data loading</strong><span>: the code below loads the Cora data object provided by dgl. A Dataset in dgl may contain multiple graphs, so the loaded dataset object is a list in which each element is one graph. Cora consists of a single graph, so we simply take it with dataset[0].</span>
</p>import dgl.data
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
<strong><span style="font-size:16px">Positive/negative split</span></strong><span style="font-size:16px">: randomly take 10% of the edges in the graph as positive test samples and the remaining 90% as positive training samples; then negatively sample the same number of non-existent edges for the training and test sets, so that the positive-to-negative ratio in each set is 1:1.</span>
u, v = g.edges()
eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)
train_size = g.number_of_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]
# Sample negative examples and split them into training and test sets.
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]<p style="text-indent:0pt"><strong><span style="font-size:16px">Note:</span></strong><span style="font-size:16px"> when training the model, the 10% of edges belonging to the test set must be removed from the graph with dgl.remove_edges to prevent data leakage.</span></p>train_g = dgl.remove_edges(g, eids[:test_size])
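<p style="text-indent:0pt"><span style="font-size:16px">The complement-matrix construction above guarantees that every candidate negative pair is neither an existing edge nor a self-loop. This can be checked on a toy graph in plain numpy (a sketch with made-up edges, independent of dgl):</span></p>

```python
import numpy as np
import scipy.sparse as sp

# Toy directed graph with 4 nodes; hypothetical example edges.
u = np.array([0, 0, 1, 2])
v = np.array([1, 2, 2, 3])
num_nodes = 4

# Same construction as above: dense adjacency, then its complement
# with the diagonal (self-loops) subtracted out.
adj = sp.coo_matrix((np.ones(len(u)), (u, v)), shape=(num_nodes, num_nodes))
adj_neg = 1 - adj.todense() - np.eye(num_nodes)
neg_u, neg_v = np.where(adj_neg != 0)

edges = set(zip(u.tolist(), v.tolist()))
# No candidate negative pair is an existing edge or a self-loop:
# 16 cells - 4 diagonal - 4 edges = 8 candidates.
assert all((int(s), int(d)) not in edges and s != d for s, d in zip(neg_u, neg_v))
print(len(neg_u))  # -> 8
```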
from dgl.nn import SAGEConv
# Define a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
<p><span style="font-size:16px"><strong>
</strong></span></p><p><span style="font-size:16px"><strong>Defining a score function over node pairs</strong>: the model predicts a score from the representations of two nodes, which indicates how likely an edge exists between them. In GNN node classification the model learns a representation for a single node, whereas in link prediction it must score pairs of nodes.</span>
</p><p style="text-indent:0pt">
</p><p style="text-indent:0pt"><strong><span style="font-size:16px">Note:</span></strong><span style="font-size:16px"> before giving the score functions, it helps to understand how DGL handles node pairs: a set of node pairs can itself be viewed as a graph, where each edge describes one pair. In link prediction we therefore build a positive graph containing all positive pairs as edges and a negative graph containing all negative pairs. Both contain the same node set as the original graph, which makes it easy to pass node features between graphs: the node representations computed on the full graph can be fed directly to the positive and negative graphs to compute pairwise scores.</span></p><p style="text-indent:0pt">
</p><p style="text-indent:0pt"><span style="font-size:16px">Treating node pairs as a graph also lets us use the DGLGraph.apply_edges method, which conveniently computes features for the newly created edges from the node features and the original edge features. DGL provides a set of optimized built-in functions that compute new edge features directly from the raw node/edge features.</span>
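</p><p style="text-indent:0pt"><span style="font-size:16px">As an illustration of what such an edge-wise built-in does (a plain-numpy sketch with made-up embeddings, not dgl's actual implementation), fn.u_dot_v computes one scalar per edge: the dot product of the source and destination node features:</span></p>

```python
import numpy as np

# Hypothetical 2-dimensional embeddings for 3 nodes.
h = np.array([[1.0, 2.0],
              [0.5, -1.0],
              [2.0, 0.0]])

# Edge list (source, destination), like the edges of a DGLGraph.
src = np.array([0, 1, 2])
dst = np.array([1, 2, 0])

# Equivalent of g.apply_edges(fn.u_dot_v('h', 'h', 'score')):
# one dot product per edge.
score = (h[src] * h[dst]).sum(axis=1)
print(score.tolist())  # -> [-1.5, 1.0, 2.0]
```
<p style="text-indent:0pt">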
</p>train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
<p><span style="font-size:16px"><strong>Built-in and custom score functions</strong>: next we define the score function between two nodes. We can either use one provided by DGL or define our own: DotPredictor below uses DGL's built-in dot-product function, while MLPPredictor is a custom predictor.</span></p>import dgl.function as fn
class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # Compute a score for each edge as the dot product between the
            # source node feature 'h' and the destination node feature 'h'
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector per edge, so take element 0
            return g.edata['score'][:, 0]
<p>
</p>class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """
        Computes a scalar score for each edge of the given graph.

        Parameters
        ----------
        edges :
            Has three members ``src``, ``dst`` and ``data``, each of
            which is a dictionary representing the features of the
            source nodes, the destination nodes, and the edges
            themselves.

        Returns
        -------
        dict
            A dictionary of new edge features.
        """
        h = torch.cat([edges.src['h'], edges.dst['h']], 1)
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']
<p><span style="font-size:16px">Next, instantiate the model and the predictor, define the compute_loss and compute_auc helpers used by the training loop (binary cross-entropy over the positive/negative edge scores, and sklearn's AUC), and train the model:</span></p>model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
pred = DotPredictor()  # or MLPPredictor(16)

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=0.01)
all_logits = []
for e in range(100):
    # forward pass
    h = model(train_g, train_g.ndata['feat'])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)
    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if e % 5 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))
# check the accuracy of the results
from sklearn.metrics import roc_auc_score
with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print('AUC', compute_auc(pos_score, neg_score))<p><span style="font-size:16px">
</span></p><p><span style="font-size:16px">The code above is the model's training loop, similar to any pytorch training procedure; it produces output like the following:</span></p>In epoch 0, loss: 0.6172636151313782
In epoch 5, loss: 0.6101921796798706
In epoch 10, loss: 0.5864554047584534
In epoch 15, loss: 0.5405876040458679
In epoch 20, loss: 0.4583510458469391
In epoch 25, loss: 0.39045605063438416
In epoch 30, loss: 0.34702828526496887
In epoch 35, loss: 0.3122958838939667
In epoch 40, loss: 0.2834944725036621
In epoch 45, loss: 0.25488677620887756
In epoch 50, loss: 0.22920763492584229
In epoch 55, loss: 0.20638766884803772
In epoch 60, loss: 0.18289318680763245
In epoch 65, loss: 0.16009262204170227
In epoch 70, loss: 0.1381770521402359
In epoch 75, loss: 0.11725720018148422
In epoch 80, loss: 0.09779688715934753
In epoch 85, loss: 0.07947927713394165
In epoch 90, loss: 0.06309689581394196
In epoch 95, loss: 0.048749890178442
AUC 0.8526520069180836<p>
</p><p><span style="font-size:18px"><strong>Previous posts</strong></span></p><p>【知识图谱系列】基于生成式的知识图谱预训练模型</p><p>【知识图谱系列】基于实数或复数空间的知识图谱嵌入</p><p>【知识图谱系列】知识图谱多跳推理之强化学习</p><p>【知识图谱系列】动态时序知识图谱EvolveGCN</p><p>【机器学习系列】机器学习中的两大学派</p></div>
「GNN Framework Series」DGL Lecture 2: Implementing GNN Link Prediction