TextRank起源与PageRank
TextRank的灵感来源于大名鼎鼎的PageRank算法,这是一个用作网页重要度排序的算法。
并且,这个算法也是基于图的,每个网页可以看作是一个图中的结点,如果网页A能够跳转到网页B,那么则有一条A->B的有向边。这样,我们就可以构造出一个有向图了。
然后,利用公式:
经过多次迭代就可以获得每个网页对应的权重。下面解释公式每个元素的含义:
: 网页的重要度(权重),初始值可设为1。
: 阻尼系数,一般为0.85。
:能跳转到网页的页面,在图中对应入度对应的点。
:网页能够跳转到的页面,在图中对应出度的点。
可以发现,这个方法只要构造好图,对应关系自然就有了,这实际上是一个比较通用的算法。那么对于文本来说,也是同样的,只要我们能够构造出一个图,图中的结点是单词or句子,只要我们通过某种方法定义这些结点存在某种关系,那么我们就可以使用上面的算法,得到一篇文章中的关键词or摘要。
使用TextRank提取关键词
提取关键词,和网页中选哪个网页比较重要其实是异曲同工的,so,我们只需要想办法把图构建出来就好了。
图的结点其实比较好定义,就是单词喽,把文章拆成句子,每个句子再拆成单词,以单词为结点。
那么边如何定义呢?这里就可以利用n-gram的思路,简单来说,某个单词,只与它附近的n个单词有关,即与它附近的n个词对应的结点连一条无向边(两个有向边)。
另外,还可以做一些操作,比如把某类词性的词删掉,一些自定义词删掉,只保留一部分单词,只有这些词之间能够连边。
下面是论文中给出的例子:
当构图成功以后,就可以利用上面的公式进行迭代求解了。
使用TextRank提取文章摘要
提取关键词以单词为结点,很显然,提取文章摘要自然就是以句子为结点了。那么边呢?如何定义呢?上面的方法似乎不是很适用了,因为两个句子即使相邻,也可以去讲完全不同的两件事。
在论文里,作者给出了一个方法,那就是计算两个句子的相似度。我的理解是这样的,这个计算相似度,其实就是一个比较粗略的方法来判断这两个句子是不是在讲同一个事情,如果两个句子是讲同一个事情,那么肯定会使用相似的单词之类的,这样就可以连一个边了。
既然有了相似度,那么就会有两个句子很相似,两个句子不太相似的情况了,因此,连的边也需要是带权值的边了。
下面是论文中给出的相似度的公式:
:第i个句子。
:第k个单词。
:句子i中单词数。
简单来说就是,两个句子单词的交集除以两个句子的长度(至于为什么用log,没想明白,论文里也没提)。然后还有一点,就是,其他计算相似度的方法应该也是可行的,比如余弦相似度,最长公共子序列之类的,不过论文里一笔带过了。
由于使用了带权的边,因此公式也要进行相应的修改:
上面的公式基本上就是把原来对应边的部分添加了权重,边的数量和改成了权重和,很好理解。
TextRank文章摘要代码实现
我这里只实现了文章摘要的代码,关键词提取也是类似的。要知道,TextRank是不需要训练的,因此,只需要一份测试数据就可以了,数据我是在这里下的:
http://tcci.ccf.org.cn/conference/2018/taskdata.php
上面链接中,Task3就是一个文章摘要的问题,直接下测试集就可以了。
有了测试集,还要有评估方法,这里我用了ROUGE的评价方法,大家有兴趣可以搜一下,这里贴一下简单的代码,如果有误,欢迎指出:
#-*- encoding:utf-8 -*-
import os, sys
class Rouge(object):
"""docstring for Rouge"""
n_gram = 1
def __init__(self, n_gram=2):
super(Rouge, self).__init__()
self.n_gram = n_gram
def get_ngrams(self, text):
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - self.n_gram
for i in xrange(max_index_ngram_start + 1):
ngram_set.add(text[i:i + self.n_gram])
return ngram_set
def get_test_result(self, evaluated_text, reference_text):
evaluated_ngram_set = self.get_ngrams(evaluated_text)
reference_ngram_set = self.get_ngrams(reference_text)
evaluated_count = len(evaluated_ngram_set)
reference_count = len(reference_ngram_set)
overlapping_ngrams = evaluated_ngram_set.intersection(reference_ngram_set)
overlapping_count = len(overlapping_ngrams)
if evaluated_count == 0:
precision = 0.0
else:
precision = overlapping_count * 1.0 / evaluated_count
if reference_count == 0:
recall = 0.0
else:
recall = overlapping_count / (reference_count + 1e-8)
f1_score = 2.0*precision*recall/(precision + recall + 1e-8)
return {"f1_score": f1_score, "precision": precision, "recall": recall}
接下来是TextRank的实现,这里有两个问题,一个是分句,一个是分词。分句的话,我就是用简单的结束符号来分割(。?!)。分词的话,我使用了jieba分词,在github上可以搜一下。另外,由于一个句子中的单词其实并不多的,所以我就用了邻接矩阵来实现(主要是懒),这样写还有个好处就是,上面的公式可以用矩阵乘法实现了,写起来很简洁。
#-*- encoding:utf-8 -*-
import os, sys, math
import numpy as np
import jieba
class SplitWords(object):
"""docstring for SplitWords"""
sentence_end_words = u'。?!'
def __init__(self):
super(SplitWords, self).__init__()
def split_sentence(self, text):
results = []
n = len(text)
now_text = ''
for i in xrange(n):
ch = text[i]
if ch in self.sentence_end_words:
results.append({'sentence': now_text, 'key_words':set()})
now_text = ''
else:
now_text += ch
if now_text != '':
results.append({'sentence': now_text, 'key_words':set()})
return results
def process(self, text):
results = self.split_sentence(text)
for item in results:
seg_list = jieba.cut(item['sentence'], cut_all=False)
for word in seg_list:
item['key_words'].add(word)
return results
class TextRank(object):
"""docstring for TextRank"""
def __init__(self, iter_times=20, d=0.85):
super(TextRank, self).__init__()
self.iter_times = iter_times
self.d = d
self.word_spilter = SplitWords()
def cal_sentence_similarity(self, sentence_item1, sentence_item2):
intersect_item = sentence_item1['key_words'].intersection(sentence_item2['key_words'])
intersect_count = len(intersect_item)
n1 = len(sentence_item1['key_words'])
n2 = len(sentence_item2['key_words'])
# print intersect_count, n1, n2
similarity_value = intersect_count / (math.log(n1 + 1e-8) + math.log(n2 + 1e-8) + 1e-8)
return similarity_value
def sort_sentences(self, splited_words, ver_ws):
temp_list = []
for i, item in enumerate(splited_words):
temp_list.append((ver_ws[i][0], item['sentence']))
temp_list.sort()
return temp_list
def process(self, text):
splited_words = self.word_spilter.process(text)
sentence_length = len(splited_words)
weights = np.zeros((sentence_length, sentence_length), dtype=float)
mul_mat = np.zeros((sentence_length, sentence_length), dtype=float)
for i in range(sentence_length):
for j in range(i + 1, sentence_length):
weights[i][j] = weights[j][i] = self.cal_sentence_similarity(splited_words[i], splited_words[j])
for i in range(sentence_length):
for j in range(sentence_length):
if weights[i][j] == 0:
continue
mul_mat[i][j] = weights[j][i]/(weights[j].sum() + 1e-8)
ver_ws = np.ones((sentence_length, 1), dtype=float)
for i in range(self.iter_times):
ver_ws = (1 - self.d) + self.d * np.dot(mul_mat, ver_ws)
# print ver_ws
return self.sort_sentences(splited_words, ver_ws)
最后是测试代码:
#-*- encoding:utf-8 -*-
import sys
import json
import codecs
from optparse import OptionParser
from textrank4zh import TextRank4Keyword, TextRank4Sentence
from rouge import Rouge
from text_rank import TextRank
def load_data(input_file):
data_list = []
with codecs.open(input_file, 'r', encoding='UTF-8') as f:
for line in f:
line_data = json.loads(line.encode('utf-8'))
# print line_data['summarization']
# print line_data['article']
data_list.append(line_data)
return data_list
def run(options):
rouge_tester = Rouge(2)
data_list = load_data(options.input_file)
tr4s = TextRank4Sentence()
my_text_rank = TextRank()
avg_result = {'f1_score': 0, "precision": 0, "recall": 0}
# total_count = 0
for item in data_list:
article = item['article']
reference_text = item['summarization']
# tr4s.analyze(text=article, lower=True, source = 'all_filters')
evaluated_text = ''
#for item in tr4s.get_key_sentences(num=1):
# # print(item.index, item.weight, item.sentence)
# evaluated_text = item.sentence
rank_results = my_text_rank.process(article)
evaluated_text = rank_results[-1][1]
test_result = rouge_tester.get_test_result(evaluated_text, reference_text)
print evaluated_text, reference_text
print test_result
print avg_result
for key,value in test_result.items():
avg_result[key] += value
n = len(data_list)
print 'n:{}'.format(n)
for key,value in test_result.items():
avg_result[key] /= (n + 1e-8)
print avg_result
def init_options():
parser = OptionParser()
parser.add_option('-i', '--input_file', dest='input_file', default='Downloads/TTNewsCorpus_NLPCC2017/toutiao4nlpcc_eval/evaluation_with_ground_truth.txt', help='the input file')
(options, arges) = parser.parse_args()
print options
return options
if __name__ == '__main__':
options = init_options()
run(options)