「GNN框架系列」DGL第二讲:实现GNN链接预测

<span style="font-size:16px">
</span><span style="font-size:16px"><div class="image-package"><img src="https://upload-images.jianshu.io/upload_images/26011021-4ae807cb0dd8cf89.jpeg" img-data="{"format":"jpeg","size":39554,"height":449,"width":1440}" class="uploaded-img" style="min-height:200px;min-width:200px;" width="auto" height="auto"/>
</div><p>本文先简单概述GNN链接预测任务,接下来使用Deep Graph Library实现GNN进行链接预测,并对代码进行详细介绍,若需获取模型的完整代码,可关注公众号【AI机器学习与知识图谱】后回复:<strong>DGL第二讲完整代码</strong>
</p><p><strong>
</strong></p></span><strong><font size="4">一、GNN链接预测概述</font></strong><div><font size="5"><b>
</b></font><span style="font-size:16px"><span>GNN链接预测任务,即预测图中两个节点之间的边是否存在。在</span></span><span style="font-size:16px">Social Recommendation,</span><span style="font-size:16px">Knowledge G</span><span style="font-size:16px"><span>raph Completion等应用中都需要进行链接预测。模型实现上是将链接预测任务看成一个二分类任务:</span></span><span style="font-size:16px">
</span><span style="font-size:16px">1. 将图中存在的边作为正样本;</span><span style="font-size:16px">2. 负采样一些图中不存在的边作为负样本;</span><span style="font-size:16px">3. 将正样例和负样例合并划分为训练集和测试集;</span><span style="font-size:16px">4. 可以采用二分类模型的评估指标来评估模型的效果,例如:AUC值</span><span style="font-size:16px">
</span><span style="font-size:16px"><span>在一些场景下例如大规模推荐系统或信息检索,模型需要评估top</span><span>-k</span><span>预测结果的准确性,因此对于链接预测任务还需要一些其他的评估指标来衡量模型最终效果:</span></span><span style="font-size:16px">1. MR(MeanRank)</span><span style="font-size:16px">2. MRR(Mean Reciprocal Rank)</span><span style="font-size:16px">3. Hit@n</span><span style="font-size:16px">
</span><span style="font-size:16px"><span>MR, MRR, Hit@n</span><span>指标含义:假设整个图谱中共n个实体,评估前先进行如下操作:</span></span><span style="font-size:16px"><span>(1)</span><span>将一个正确的三元组</span><span><span>中的头实体</span><span>h</span><span>或者尾实体t,依次替换成整个图谱中的其他所有实体,这样会产生n个三元组;</span></span></span><span style="font-size:16px"><span>(2)</span><span>对(</span><span>1</span><span>)中产生的n个三元组分别计算其能量值,例如在TransE中计算</span><span><span>的值,这样n个三元组分别对应自己的能量值;</span></span></span><span style="font-size:16px"><span>(3)</span><span>对上述n个三元组按照能量值进行升序排序,记录每个三元组排序后的序号;</span></span><span style="font-size:16px"><span>(4)</span><span>对所有正确的三元组都进行上述三步操作</span></span><span style="font-size:16px">
</span><p style="text-indent:0pt"><strong><span style="font-size:16px">MR指标:</span></strong><span style="font-size:16px">将整个图谱中每个正确三元组的能量值排序后的序号取平均得到的值;</span></p><p style="text-indent:0pt"><strong><span style="font-size:16px">MRR指标:</span></strong><span style="font-size:16px">将整个图谱每个正确三元组的能量排序后的序号倒数取平均得到的值;</span></p><strong><span style="font-size:16px"><span>Hit@n</span><span>指标:</span></span></strong><span style="font-size:16px"><span>整个图谱正确三元组的能量排序后序号小于n的三元组所占的比例。</span></span><p style="text-indent:0pt"><span style="font-size:16px">
</span></p><p style="text-indent:0pt"><span style="font-size:16px"><span>因此对于链接预测任务来说,MR指标越小,模型效果越好;MRR和Hit</span><span>@n</span><span>指标越大,模型效果越好。</span></span><span style="font-size:16px">接下来本文将在Cora引文数据集上,预测两篇论文之间是否存在引用关系或被引用关系。</span></p>
<span style="font-size:16px"><span style="font-size:20px"><strong>二、GNN链接预测实现</strong></span></span><p style="text-indent:0pt"><span style="font-size:16px"><span><span style="font-size:16px">接下来使用DGL框架实现GNN模型进行链接任务,对代码给出详细解释。<span style="font-size:16px">首先如下所示,先加载需要使用的dgl库和pytorch库;</span></span></span></span></p><span>import</span> dgl
<span>import</span> torch
<span>import</span> torch.nn <span>as</span> nn
<span>import</span> torch.nn.functional <span>as</span> F
<span>import</span> itertools
<span>import</span> numpy <span>as</span> np
<span>import</span> scipy.sparse <span>as</span> sp<p style="text-indent:0pt"><strong>数据加载</strong><span>:下面代码加载dgl库提供的Cora数据对象,dgl库中Dataset数据集可能是包含多个图的,所以加载的dataset对象是一个list,list中的每个元素对应该数据的一个graph,但Cora数据集是由单个图组成,因此直接使用dataset[0]取出graph。</span>
</p><span>import</span> dgl.data
dataset = dgl.data.CoraGraphDataset()
g = dataset[<span>0</span>]<strong><span style="font-size:16px">
</span></strong><strong><span style="font-size:16px">正负数据划分</span></strong><span style="font-size:16px">:随机抽取数据集中10%的边作为测试集中的正样例,剩下的90%数据集中的边作为训练集,然后随机为训练集和测试集,负采样生成相同数量的负样例,使得训练集和测试集中的正负样本比例为1:1,将数据集中边的集合划分到训练集和测试集中,训练集90%,测试集10%</span>
u, v = g.edges()
eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * <span>0.1</span>)
train_size = g.number_of_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

<span># 采样所有负样例并划分为训练集和测试集中。</span>
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = <span>1</span> - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != <span>0</span>)

neg_eids = np.random.choice(len(neg_u), g.number_of_edges() // <span>2</span>)
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]<p style="text-indent:0pt"><strong><span style="font-size:16px">注意:</span></strong><span style="font-size:16px">在模型训练时,需要将图中在测试集中10%的边移除掉,防止数据泄露,使用dgl.remove_edges</span></p>train_g = dgl.remove_edges(g, eids[:test_size])
<span>from</span> dgl.nn <span>import</span> SAGEConv

<span># 定义一个两层的GraphSage模型</span>
<span><span>class</span> <span>GraphSAGE</span><span>(nn.Module)</span>:</span>
<span><span>def</span> <span>init</span><span>(self, in_feats, h_feats)</span>:</span>
super(GraphSAGE, self).init()
self.conv1 = SAGEConv(in_feats, h_feats, <span>'mean'</span>)
self.conv2 = SAGEConv(h_feats, h_feats, <span>'mean'</span>)

<span><span>def</span> <span>forward</span><span>(self, g, in_feat)</span>:</span>
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
<span>return</span> h<p><span style="font-size:16px"><strong>
</strong></span></p><p><span style="font-size:16px"><strong>节点对得分函数定义方式</strong>:模型通过定义函数来预测两个节点表示之间的得分,从而来判断两个节点之间存在边的可能性,在GNN节点分类任务重,模型是训练得到单个节点的表征,但在链接计算任务中是预测节点对的表征</span>
</p><p style="text-indent:0pt">
</p><p style="text-indent:0pt"><strong><span style="font-size:16px">注意:</span></strong><span style="font-size:16px">在给出节点对的预测得分函数之前,先需要理解一下:DGL使用方式是先将节点对视为一个图,同时一条边可用来描述一对节点。在链接预测中,会得到一个正图,它包含所有的正例子作为边,以及一个负图,它包含所有的负例子。正图和负图将包含与原始图相同的节点集。这使得在多个图中传递节点特征更容易进行计算。可以直接将在整个图上计算的节点表示形式提供给正图和负图,用于计算节点对的两两得分。</span></p><p style="text-indent:0pt">
</p><p style="text-indent:0pt"><span style="font-size:16px">将节点对视为图的好处是可以使用DGLGraph.apply_edges方法,基于节点的特征表示和原始图中边的特征表示可以方便地计算新产生边的特征表示。DGL提供了一组优化的内置函数,可以直接使用原始节点/边特征表示计算新的边特征表示。</span>
</p>train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())<p><span style="font-size:16px"><strong>
</strong></span></p><p><span style="font-size:16px"><strong>官方和自定义两种得分函数实现</strong>:接下来定义两个节点之间的得分函数预测,可以直接使用DGL提供的,也可以自定义,下面DotPredictor是官方提供的预测函数,MLPPredictor是自定义的预测函数。</span></p><span>import</span> dgl.function <span>as</span> fn

<span><span>class</span> <span>DotPredictor</span><span>(nn.Module)</span>:</span>
<span><span>def</span> <span>forward</span><span>(self, g, h)</span>:</span>
<span>with</span> g.local_scope():
g.ndata[<span>'h'</span>] = h
<span># 通过源节点特征“h”和目标节点特征“h”之间的点积计算两点之间存在边的Score</span>
g.apply_edges(fn.u_dot_v(<span>'h'</span>, <span>'h'</span>, <span>'score'</span>))
<span># u_dot_v为每条边返回一个元素向量,因此需要squeeze操作</span>
<span>return</span> g.edata[<span>'score'</span>][:, <span>0</span>]<p>
</p><span><span>class</span> <span>MLPPredictor</span><span>(nn.Module)</span>:</span>
<span><span>def</span> <span>init</span><span>(self, h_feats)</span>:</span>
super().init()
self.W1 = nn.Linear(h_feats * <span>2</span>, h_feats)
self.W2 = nn.Linear(h_feats, <span>1</span>)

<span><span>def</span> <span>apply_edges</span><span>(self, edges)</span>:</span>
<span>"""
Computes a scalar score for each edge of the given graph.</span><span>
Parameters
----------
edges :
Has three members src, dst and data, each of
which is a dictionary representing the features of the
source nodes, the destination nodes, and the edges
themselves.

Returns
-------
dict
A dictionary of new edge features.
"""</span>
h = torch.cat([edges.src[<span>'h'</span>], edges.dst[<span>'h'</span>]], <span>1</span>)
<span>return</span> {<span>'score'</span>: self.W2(F.relu(self.W1(h))).squeeze(<span>1</span>)}

<span><span>def</span> <span>forward</span><span>(self, g, h)</span>:</span>
<span>with</span> g.local_scope():
g.ndata[<span>'h'</span>] = h
g.apply_edges(self.apply_edges)
<span>return</span> g.edata[<span>'score'</span>]<p><span style="font-size:16px">接下来直接进行模型训练:</span><span style="font-size:16px"/></p>optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=<span>0.01</span>)

all_logits = []
<span>for</span> e <span>in</span> range(<span>100</span>):
<span># 前向传播</span>
h = model(train_g, train_g.ndata[<span>'feat'</span>])
pos_score = pred(train_pos_g, h)
neg_score = pred(train_neg_g, h)
loss = compute_loss(pos_score, neg_score)

<span># 后向传播</span>
optimizer.zero_grad()
loss.backward()
optimizer.step()

<span>if</span> e % <span>5</span> == <span>0</span>:
print(<span>'In epoch {}, loss: {}'</span>.format(e, loss))

<span># 检测结果准确性</span>
<span>from</span> sklearn.metrics <span>import</span> roc_auc_score
<span>with</span> torch.no_grad():
pos_score = pred(test_pos_g, h)
neg_score = pred(test_neg_g, h)
print(<span>'AUC'</span>, compute_auc(pos_score, neg_score))<p><span style="font-size:16px">
</span></p><p><span style="font-size:16px">上面是模型的训练函数,和pytorch模型训练过程都是相似的,训练过程如下图所示:</span></p>In epoch <span>0</span>, loss: <span>0.6172636151313782</span>
In epoch <span>5</span>, loss: <span>0.6101921796798706</span>
In epoch <span>10</span>, loss: <span>0.5864554047584534</span>
In epoch <span>15</span>, loss: <span>0.5405876040458679</span>
In epoch <span>20</span>, loss: <span>0.4583510458469391</span>
In epoch <span>25</span>, loss: <span>0.39045605063438416</span>
In epoch <span>30</span>, loss: <span>0.34702828526496887</span>
In epoch <span>35</span>, loss: <span>0.3122958838939667</span>
In epoch <span>40</span>, loss: <span>0.2834944725036621</span>
In epoch <span>45</span>, loss: <span>0.25488677620887756</span>
In epoch <span>50</span>, loss: <span>0.22920763492584229</span>
In epoch <span>55</span>, loss: <span>0.20638766884803772</span>
In epoch <span>60</span>, loss: <span>0.18289318680763245</span>
In epoch <span>65</span>, loss: <span>0.16009262204170227</span>
In epoch <span>70</span>, loss: <span>0.1381770521402359</span>
In epoch <span>75</span>, loss: <span>0.11725720018148422</span>
In epoch <span>80</span>, loss: <span>0.09779688715934753</span>
In epoch <span>85</span>, loss: <span>0.07947927713394165</span>
In epoch <span>90</span>, loss: <span>0.06309689581394196</span>
In epoch <span>95</span>, loss: <span>0.048749890178442</span>
AUC <span>0.8526520069180836</span><p>
</p><p><span style="font-size:18px"><strong>往期精彩</strong></span></p><p>【知识图谱系列】基于生成式的知识图谱预训练模型</p><p>【知识图谱系列】基于实数或复数空间的知识图谱嵌入
</p><p><span style="font-size:14px"/></p><p>【知识图谱系列】知识图谱多跳推理之强化学习
</p><p><span style="font-size:14px"/></p><p>【知识图谱系列】动态时序知识图谱EvolveGCN</p><p><span style="font-size:14px"/></p><p>【机器学习系列】机器学习中的两大学派</p></div>

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,451评论 6 506
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,172评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,782评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,709评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,733评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,578评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,320评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,241评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,686评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,878评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,992评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,715评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,336评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,912评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,040评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,173评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,947评论 2 355

推荐阅读更多精彩内容