Util.py
import os
import random
import json
from time import time

import numpy as np
import torch

import Constants

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_weights(word2index, path, myembed_path=None):
    """Load pretrained vectors for the words in word2index; words without a vector are randomly initialized."""
    index2word = [word for word, idx in word2index.items()]
    weights = [[-1]] * len(index2word)  # placeholder rows, replaced below
    print("Loading pretrained word vectors from", os.path.abspath(path))
    f = open(path, "r", encoding="utf-8")
    embeds = []
    embed_dim = 300
    for line in f:
        if len(line) < 300:
            # short line, e.g. the "<vocab_size> <dim>" header of word2vec-style files
            print("Pretrained embedding size:", line)
            continue
        if line[1] != ' ' or line[0] not in word2index:  # skip multi-character words and characters not in the vocabulary
            continue
        tokens = line.strip().split(" ")
        if tokens[0] not in word2index:
            continue
        index = word2index[tokens[0]]
        tokens = [float(token) for token in tokens[1:]]
        weights[index] = np.asarray(tokens)
        embeds.append(line)
    f.close()
    if myembed_path:
        myembed = open(myembed_path, "w", encoding="utf-8")
        myembed.writelines(embeds)
        myembed.close()
        print("Filtered embeddings written to", myembed_path)
    strange_words = []
    for i in range(len(weights)):
        if len(weights[i]) == embed_dim:
            continue
        else:
            weights[i] = np.random.randn(embed_dim)
            strange_words.append(index2word[i])
    print(len(strange_words), "unknown words randomly initialized:", " ".join(strange_words))
    return torch.Tensor(weights)
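
# Illustrative usage sketch (not part of the original file): it assumes a
# word2vec/GloVe-style text file of 300-dimensional character vectors and a
# `counter` built elsewhere; the file path below is hypothetical.
#
#   word2index, index2word = counter2dict(counter)
#   weights = load_weights(word2index, "data/char_vectors_300d.txt")
#   embedding = torch.nn.Embedding.from_pretrained(weights, freeze=False)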

def count_word(counter, word, n=1):
    """Accumulate the frequency of word by n."""
    if word not in counter:
        counter[word] = n
    else:
        counter[word] += n

def sort_counter(counter, reverse=True):
    """Return the counter as a dict sorted by frequency (descending by default)."""
    items = sorted(counter.items(), key=lambda kv: kv[1], reverse=reverse)
    counter = dict(items)
    return counter

def counter2frequency(counter):
    """Convert raw counts into relative frequencies."""
    total = 0
    for word, num in counter.items():
        total += num
    frequency = {}
    for word, num in counter.items():
        frequency[word] = num / total
    return frequency

def counter2dict(counter, word2index=Constants.Default_Dict, min_freq=2, max_token=10000):
    """Build a vocabulary from a frequency counter (note: the default word2index dict is shared across calls)."""
    ignored_word_count = 0
    for word, count in counter.items():
        if len(word2index) >= max_token:
            print("Vocabulary is full")
            break
        if word not in word2index:
            if count >= min_freq:
                word2index[word] = len(word2index)
            else:
                ignored_word_count += 1
    print('[Info] vocabulary size = {},'.format(len(word2index)), 'minimum frequency = {}'.format(min_freq))
    print("[Info] ignored rare words = {}".format(ignored_word_count))
    index2word = [k for k, v in word2index.items()]
    assert len(index2word) == len(word2index)
    return word2index, index2word
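
# Illustrative sketch of the counting -> vocabulary flow, using made-up example
# lines; count_doc is defined later in this module and Constants.Default_Dict
# comes from the project's Constants module.
#
#   counter = count_doc(["今天天气不错", "上联怎么对"])
#   word2index, index2word = counter2dict(counter, min_freq=1)
#   print(index2word[:5], word2index.get("天"))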

def get_index2word(word2index):
    index2word = []
    for word, count in word2index.items():
        index2word.append(word)
    return index2word

def sentence2indices(line, word2index, max_len=None, padding_index=None, unk=None, began=None, end=None):
    """Map a token sequence to indices; unknown tokens map to unk, or are dropped when unk is None."""
    if unk is not None:
        result = [word2index.get(word, unk) for word in line]
    else:
        result = [word2index[word] for word in line if word in word2index]
    if max_len is not None:
        result = result[:max_len]
    if began is not None:
        result.insert(0, began)
    if end is not None:
        result.append(end)
    if padding_index is not None and max_len is not None and len(result) < max_len:
        result += [padding_index] * (max_len - len(result))
    return result
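
# Illustrative sketch: encoding one sentence to a fixed-length index list.
# Constants.PAD is used elsewhere in this module; Constants.UNK is only an
# assumption here.
#
#   ids = sentence2indices("天气不错", word2index, max_len=10,
#                          padding_index=Constants.PAD, unk=Constants.UNK)
#   # len(ids) == 10, padded on the right with Constants.PAD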

def indices2sentence(index2word, indices):
    sentence = "".join(index2word[index] for index in indices)
    return sentence

def split_train(x, rate=0.90, shuffle=True):
    if shuffle:
        random.shuffle(x)
    index = int(len(x) * rate)
    train = x[:index]
    test = x[index:]
    index = int(len(test) * 0.9)
    valid = test[:index]
    test = test[index:]
    return train, valid, test
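
# Illustrative note: with the defaults, split_train yields roughly 90% train,
# 9% valid and 1% test; `lines` below is hypothetical.
#
#   train, valid, test = split_train(lines)
#   print(len(train), len(valid), len(test))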

def write_splits(x, dir="data", shuffle=True):
    if shuffle:
        random.shuffle(x)
    left = int(len(x) * 0.9)
    right = left + int(0.9 * (len(x) - left))
    with open(dir + "/train.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(x[:left]))
    with open(dir + "/valid.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(x[left:right]))
    with open(dir + "/test.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(x[right:]))
    print("Train/valid/test splits written to the", dir, "directory")

def count_doc(doc, counter=None):
    """Character-level frequency count over an iterable of lines."""
    if counter is None:  # avoid the shared mutable-default pitfall
        counter = {}
    for line in doc:
        words = list(line)
        for word in words:
            if word:
                count_word(counter, word)
    return sort_counter(counter)

def merge_counter(counter1, counter2):
    if len(counter1) > 0:
        for word, num in counter1.items():
            count_word(counter2, word, num)
    return sort_counter(counter2)

def make_batches(data, batch_size, vocab, max_len=20, shuffle=True):
    """Yield (x, y) LongTensor mini-batches; each example's second position is expected to hold its label id."""
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        if shuffle:
            random.shuffle(batch)
        x, y = [], []
        for j in range(len(batch)):
            batch[j] = batch[j][:max_len]
            x.append(sentence2indices(line=batch[j], word2index=vocab, max_len=max_len, padding_index=Constants.PAD))
            y.append(batch[j][1])
        yield torch.LongTensor(x), torch.LongTensor(y)

def doc2tensor(examples, word2index, max_len):
    data = []
    for line in examples:
        example = sentence2indices(line, word2index=word2index, max_len=max_len, padding_index=Constants.PAD)
        data.append(example)
    return torch.LongTensor(data)

def few_data(dict_data, n_class, n_support, n_batch, max_len, word2index=None, index2word=None):
    """Sample an N-way episode from dict_data ({label: [examples, ...]}): n_support support and n_batch query examples per class."""
    classes = random.sample(list(dict_data.keys()), n_class)
    labels = np.array(range(n_class))
    labels = dict(zip(classes, labels))
    support_x, support_y = [], []
    batch_x, batch_y = [], []
    for c in classes:
        examples = dict_data[c]
        examples = random.sample(examples, n_support + n_batch)
        support_x += examples[:n_support]
        support_y += [labels[c]] * n_support
        batch_x += examples[n_support:n_support + n_batch]
        batch_y += [labels[c]] * n_batch
    samples = doc2tensor(support_x, word2index, max_len)
    sample_labels = torch.LongTensor(support_y)
    batches = doc2tensor(batch_x, word2index, max_len)
    batch_labels = torch.LongTensor(batch_y)
    return samples, sample_labels, batches, batch_labels, classes
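
# Illustrative sketch of sampling one N-way K-shot episode; `dict_data` is
# assumed to map each label string to a list of raw sentences, with at least
# n_support + n_batch sentences per label.
#
#   sup_x, sup_y, qry_x, qry_y, classes = few_data(
#       dict_data, n_class=5, n_support=5, n_batch=10,
#       max_len=20, word2index=word2index)
#   # sup_x: LongTensor of shape [5 * 5, 20]; qry_x: [5 * 10, 20]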

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
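
# Illustrative check: a Linear(300, 10) layer has 300 * 10 + 10 = 3010
# trainable parameters.
#
#   layer = torch.nn.Linear(300, 10)
#   assert count_parameters(layer) == 3010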

def json_dict(path):
    doc = open(path, "r", encoding="utf-8").read().splitlines()
    counter = count_doc(doc)
    with open("data/counter.json", "w", encoding="utf-8") as f:
        json.dump(counter, f, ensure_ascii=False)
    vocab = counter2dict(counter)
    print(vocab)
    labels = {}
    for line in doc:
        line = line.split("\t")[0]
        if line not in labels:
            labels[line] = len(labels)
    print(len(labels))

def genData(path, outpath):
    """Convert the raw dataset into tab-separated category/sentence lines, keeping at most 1000 samples per category and dropping categories with fewer than 100 samples."""
    f = open(path, "r", encoding="utf-8")
    doc2 = []
    count = {}
    # raw line example:
    # 6552400379030536455_!_101_!_news_culture_!_上联:老子骑牛读书,下联怎么对?_!_
    for line in f:
        if "\t" in line:
            continue
        words = line.split("|,|")
        if len(words) != 5:
            continue
        category = words[1].strip()
        sentence = words[2].strip()
        if ',' in category:
            category = category.split(',')[0]
        if '/' in category:
            category = category.split('/')[-1]
        if not category or not sentence:
            continue
        if category not in count:
            count[category] = 1
        else:
            if count[category] >= 1000:
                continue
            count[category] += 1
        doc2.append([category, sentence])
    f.close()
    print(count)
    f = open(outpath, "w", encoding="utf-8")
    for pair in doc2:
        if count[pair[0]] < 100:
            continue
        line = '\t'.join(pair)
        f.write(line + '\n')
    f.close()

if __name__ == "__main__":
    t0 = time()
    path0 = "../../data/toutiao-text-classfication-dataset/toutiao_cat_data.txt"
    path1 = "../../data/toutiao-multilevel-text-classfication-dataset/mlc_dataset.txt"
    path = "data/toutiao.txt"
    genData(path1, path)
    # json_dict(path)
    print("Elapsed time", time() - t0)