Chinese Article Summarization Based on TextRank

An implementation of Chinese article summarization.

Preface

I recently took part in a Hua Wu (East China Five universities) competition. Our team plans to build an aggregation-style reading app, which needs Chinese article summarization, so we chose the extractive TextRank algorithm.

Principle

  • Combine all the articles into one body of text and split the text into individual sentences
  • For each sentence, average the word vectors of its words to obtain a sentence vector
  • Compute the cosine similarity between every pair of sentence vectors to build a similarity matrix (the two formulas involved are given just after this list)
  • Convert the similarity matrix into a graph whose nodes are sentences and whose edge weights are the similarity scores
  • Rank the sentences by their graph scores and take the top n sentences as the summary
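
For reference, steps 3 and 5 rest on two standard formulas: the usual definition of cosine similarity and the weighted PageRank score that TextRank applies to the sentence graph. These are the textbook formulations, not anything specific to this project. Here v_i is the vector of sentence i, w_ji is the similarity between sentences j and i, and d is the damping factor (0.85 by default in networkx's pagerank):

$$\mathrm{sim}(s_i, s_j) = \frac{v_i \cdot v_j}{\lVert v_i \rVert \, \lVert v_j \rVert}$$

$$WS(V_i) = (1 - d) + d \sum_{V_j \in In(V_i)} \frac{w_{ji}}{\sum_{V_k \in Out(V_j)} w_{jk}} \, WS(V_j)$$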

Chinese Article Summarization Based on TextRank

Reference article:

Project URL:

The project is open-sourced on GitHub: GitHub link

Source code:

# -*- coding: utf-8 -*-
import os
import re
import configparser

import jieba
import numpy as np
import networkx as nx
import pymysql
from sklearn.metrics.pairwise import cosine_similarity


class TextSummarizer:
    def __init__(self, article, num):  # article: list of paragraphs, num: number of summary sentences
        self.article = article
        self.num = num
        self.word_embeddings = {}
        self.stopwords = None
        self.sentences_vectors = []
        self.similarity_matrix = None
        self.ranked_sentences = None
        self.text_str = ""

    def __word_embeddings(self):  # load the pretrained word vectors
        for i, line in enumerate(open('res/sgns.sogou.char', encoding='utf-8').readlines()):
            if i != 0:  # the first line only holds corpus statistics, skip it
                values = line.split()
                word = values[0]  # the first field is the word itself
                dimen = np.asarray(values[1:], dtype='float32')  # the remaining fields are the vector components
                self.word_embeddings[word] = dimen

    def __stopwords(self):  # load the stopword list
        # a list comprehension is faster here than an explicit loop
        self.stopwords = [line.strip() for line in open('res/stopwords.txt', encoding='utf-8').readlines()]

    def __sentences(self, paragraphs):  # split paragraphs into sentences
        # semicolons, dashes and English double quotes are deliberately ignored
        sentences = []
        for para in paragraphs:
            para = re.sub(r'([(),。!?\?])([^”’])', r'\1\n\2', para)  # single-character sentence delimiters
            para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)  # English ellipsis
            para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)  # Chinese ellipsis
            para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)  # put the break after a closing quote
            para = para.rstrip()  # drop the trailing newline
            sentences.extend(s for s in para.split("\n") if s)
        return sentences

    def __remove_stopwords(self, sentence):  # remove stopwords from a tokenized sentence
        return [i for i in sentence if i not in self.stopwords]

    def __sentence_vectors(self, cleaned_sentences):  # sentence vector = mean of its word vectors
        for i in cleaned_sentences:
            if len(i) != 0:
                # the pretrained vectors are 300-dimensional; unknown words fall back to a zero vector
                ave = sum([self.word_embeddings.get(j.strip(), np.zeros((300,))) for j in i]) / (len(i) + 1e-2)
            else:
                ave = np.zeros((300,))
            self.sentences_vectors.append(ave)

    def __similarity_matrix(self):  # pairwise cosine similarity between sentence vectors
        n = len(self.sentences_vectors)
        self.similarity_matrix = np.zeros((n, n))  # square matrix, one row/column per sentence
        for i in range(n):
            for j in range(n):
                if i != j:
                    # reshape(1, -1) turns each vector into a single-row matrix for cosine_similarity
                    self.similarity_matrix[i][j] = cosine_similarity(
                        self.sentences_vectors[i].reshape(1, -1),
                        self.sentences_vectors[j].reshape(1, -1))[0, 0]

    def generate_summary(self):
        self.__word_embeddings()                    # load word vectors
        self.__stopwords()                          # load stopwords
        sentences = self.__sentences(self.article)  # split the article into sentences
        cutted_sentences = [jieba.lcut(s) for s in sentences]  # tokenize each sentence
        cleaned_sentences = [self.__remove_stopwords(sentence) for sentence in cutted_sentences]  # drop stopwords
        self.__sentence_vectors(cleaned_sentences)  # build sentence vectors
        self.__similarity_matrix()                  # build the similarity matrix
        nx_graph = nx.from_numpy_array(self.similarity_matrix)  # similarity matrix -> weighted graph
        scores = nx.pagerank(nx_graph)              # PageRank score of every sentence
        self.ranked_sentences = sorted(((scores[i], s) for i, s in enumerate(sentences)), reverse=True)  # sort by score, descending
        for i in range(min(self.num, len(self.ranked_sentences))):  # take the top-scoring sentences
            self.text_str += self.ranked_sentences[i][1]
        return self.text_str


class ReadConfig:
    def __init__(self):
        mysql_config_path = os.path.join('', 'mysql_config.ini')
        self.cf = configparser.ConfigParser()
        self.cf.read(mysql_config_path, encoding='utf-8')

    def __mysql_read(self, param):
        return self.cf.get('mysql', param)

    def mysql_config(self):
        mysql = pymysql.connect(  # connect to the database
            host=self.__mysql_read('host'),
            port=int(self.__mysql_read('port')),
            user=self.__mysql_read('user'),
            passwd=self.__mysql_read('password'),
            db=self.__mysql_read('datebase'),  # note: the key is spelled "datebase" and must match mysql_config.ini
            charset=self.__mysql_read('charset'))
        cur = mysql.cursor()  # create a cursor
        try:
            sql = 'select id, content from article where summary is null'
            cur.execute(sql)
            rows = cur.fetchall()
            for row in rows:
                summary = TextSummarizer(row[1].split("\n"), 3).generate_summary()
                cur.execute('update article set summary = %s where id = %s', (summary, row[0]))
                mysql.commit()
            mysql.close()
        except Exception as msg:
            print("ERROR ! ! !")
            print(str(msg))
            mysql.rollback()
            mysql.close()


if __name__ == '__main__':
    ReadConfig().mysql_config()
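
If you only want the summary without the MySQL step, the TextSummarizer class can be used on its own. A minimal sketch, assuming the code above is saved as summarizer.py and the input file is demo.txt (both names are placeholders, not part of the original project), with res/sgns.sogou.char and res/stopwords.txt in place as above:

# Hypothetical standalone usage; summarizer.py and demo.txt are placeholder names.
from summarizer import TextSummarizer

paragraphs = open('demo.txt', encoding='utf-8').read().split('\n')
print(TextSummarizer(paragraphs, 3).generate_summary())  # 3-sentence summary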
