forked from ShawnXiha/AI-challenger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
word2vec.py
72 lines (60 loc) · 2.37 KB
/
word2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from gensim.models import Word2Vec
from multiprocessing import cpu_count
from tqdm import tqdm
import re
import pickle
import jieba
import numpy as np
# --- English tokenization helpers -----------------------------------------
re_apos = re.compile(r"(\w)'s\b")           # make possessive 's a separate word
re_mw_punc = re.compile(r"(\w[’'])(\w)")    # other ' in a word creates 2 words
re_punc = re.compile("([\"().,;:/_?!—])")   # add spaces around punctuation
# BUG FIX: was r" *" — that pattern also matches the empty string between
# every pair of characters, and under Python 3.7+ re.sub zero-width-match
# semantics it inserts a space between EVERY character, so the final
# .split() tokenized each character separately.  r" +" collapses runs of
# spaces into one, which is what the original comment intended.
re_mult_space = re.compile(r" +")           # replace multiple spaces with just one
embedding_dim = 300  # dimensionality of the trained word vectors


def simple_toks(sent):
    """Tokenize an English sentence into a list of lowercase tokens.

    Splits possessive 's into its own token, splits words joined by an
    apostrophe, pads punctuation with spaces, treats '-' as a separator,
    collapses repeated spaces, then lowercases and splits on whitespace.
    """
    sent = re_apos.sub(r"\1 's", sent)
    sent = re_mw_punc.sub(r"\1 \2", sent)
    sent = re_punc.sub(r" \1 ", sent).replace('-', ' ')
    sent = re_mult_space.sub(' ', sent)
    return sent.lower().split()
class MySentences(object):
    """Re-iterable stream over an English corpus file.

    Yields one token list per line.  Reopens the file on every pass, which
    gensim's Word2Vec requires (it scans the corpus multiple times).
    """

    def __init__(self, dirname):
        # dirname: path to a plain-text file, one sentence per line
        self.dirname = dirname

    def __iter__(self):
        # FIX: use a context manager so the file handle is closed even when
        # iteration stops early; specify utf-8 explicitly instead of the
        # platform default, matching MyChineseSentences.
        with open(self.dirname, encoding='utf-8') as f:
            for line in tqdm(f):
                yield simple_toks(line)
class MyChineseSentences(object):
    """Re-iterable stream over a Chinese corpus file.

    Yields one jieba-segmented token list per line; reopens the file on
    every pass so gensim's multi-epoch scans work.
    """

    def __init__(self, dirname):
        # dirname: path to a utf-8 text file, one sentence per line
        self.dirname = dirname

    def __iter__(self):
        # FIX: context manager guarantees the file handle is closed even
        # when the consumer abandons the iterator early.
        with open(self.dirname, encoding='utf-8') as f:
            for line in tqdm(f):
                yield list(jieba.cut(line))
def train_save_wordvec_export(lan='en'):
    """Train a skip-gram Word2Vec model on the corpus for *lan* and export it.

    Writes two files under input/:
      - <lan>vocab.pkl    : dict mapping token -> integer id
                            (the four function words get ids 0..3)
      - <lan>word2vec.npy : float32 matrix of shape (vocab, embedding_dim)
                            whose rows are aligned with those ids

    lan: 'en' uses the whitespace/punctuation tokenizer; any other value
         ('zh') uses jieba segmentation.
    """
    output = 'input/' + lan
    if lan == 'en':
        corpus = MySentences('input/train.en')
    else:
        corpus = MyChineseSentences('input/train.zh')
    model = Word2Vec(corpus, size=embedding_dim, min_count=10, sg=1, workers=cpu_count())
    print('model trained!')
    vocabulary = model.wv.vocab
    # Special tokens take the first ids: sentence padding/start/end plus OOV.
    function_words = ['PADDING', 'START', 'END', 'OOV_WORD']
    words = function_words + list(vocabulary.keys())
    _word2id = dict(zip(words, range(len(words))))
    input_dim = len(words)
    n_special = len(function_words)
    # Build the trained-vector rows in the same order as `words`, so the
    # matrix rows provably line up with the ids assigned in _word2id
    # (use model.wv[...] rather than the deprecated model[...] access).
    embeddings = np.array([model.wv[word] for word in words[n_special:]],
                          dtype=np.float32)
    weights = np.zeros((input_dim, embedding_dim), np.float32)  # row 0 (PADDING) stays zero
    weights[1] = np.ones(embedding_dim, np.float32) * 0.33      # START
    weights[2] = np.ones(embedding_dim, np.float32) * 0.66      # END
    weights[3] = np.average(embeddings, axis=0)                 # OOV_WORD = mean vector
    weights[n_special:] = embeddings                            # trained vectors
    # FIX: close the pickle file deterministically.
    with open(output + 'vocab.pkl', 'wb') as f:
        pickle.dump(_word2id, f)
    # BUG FIX: the original saved `embeddings`, silently discarding the
    # `weights` matrix it had just built — the saved matrix then disagreed
    # with _word2id (every id offset by the 4 function words, and no
    # PADDING/START/END/OOV rows).  Save `weights`.
    np.save(output + 'word2vec.npy', weights)
    print('word vector saved!')
if __name__ == '__main__':
    # Train and export word vectors for both corpus languages.
    for _lan in ('en', 'zh'):
        train_save_wordvec_export(lan=_lan)