-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
163 lines (132 loc) · 4.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from __future__ import division, print_function, unicode_literals
import sys
import os.path
from torch import nn
def fwrite(new_doc, path, mode='w', no_overwrite=False):
    """Write the string `new_doc` to the file at `path`.

    Parameters
    ----------
    new_doc: str — content to write.
    path: str — destination file path; a falsy path makes this a no-op.
    mode: str — file open mode, e.g. 'w' or 'a'.
    no_overwrite: bool — when True and `path` already exists, drop into pdb
        so the user can decide interactively, then return without writing.
    """
    if not path:
        print("[Info] Path does not exist in fwrite():", str(path))
        return

    already_there = no_overwrite and os.path.isfile(path)
    if already_there:
        print("[Error] pls choose whether to continue, as file already exists:",
              path)
        # Deliberate interactive breakpoint: lets the caller inspect and decide.
        import pdb
        pdb.set_trace()
        return

    with open(path, mode) as out_file:
        out_file.write(new_doc)
def show_time(what_happens='', cat_server=False, printout=True):
    """Optionally print a timestamped line and return a compact time string.

    Parameters
    ----------
    what_happens: str — optional label appended to the printed line.
    cat_server: bool — when True, append a short hostname tag to the result
        (e.g. "r03" for host "rosetta3", else the hostname's first letter).
    printout: bool — when False, suppress the printed line.

    Returns
    -------
    str: current time formatted as 'MMDDHHMM', plus the hostname tag when
    `cat_server` is True.
    """
    import datetime

    now = datetime.datetime.now()
    disp = '⏰ Time: ' + now.strftime('%m%d%H%M-%S')
    if what_happens:
        disp = disp + '\t' + what_happens
    if printout:
        print(disp)

    curr_time = now.strftime('%m%d%H%M')
    if cat_server:
        import socket  # BUGFIX: original used `socket` without importing it
        hostname = socket.gethostname()
        prefix = "rosetta"
        if hostname.startswith(prefix):
            host_id = hostname[len(prefix):]
            try:
                # zero-pad numeric suffixes so "rosetta3" -> "r03"
                host_id = "{:02d}".format(int(host_id))
            except ValueError:  # non-numeric suffix: keep it verbatim
                pass
            hostname = prefix[0] + host_id
        else:
            hostname = hostname[0]
        curr_time += hostname
    return curr_time
def show_var(expression,
             joiner='\n', print=print):
    '''
    Prints out the name and value of variables from the CALLER's scope.

    Eg. if a variable with name `num` and value `1`,
    it will print "num: 1\n"

    Parameters
    ----------
    expression: ``List[str]``, required
        A list of variable name strings, evaluated in the caller's frame.
    joiner: ``str``, optional (default '\n')
        Separator between entries; any non-newline joiner prefixes the
        output with "[Info] ".
    print: ``Callable``, optional
        Output function; deliberately shadows the builtin so callers can
        redirect output (e.g. to a logger).

    Returns
    ----------
    str: the joined "name: value" output that was printed.
    '''
    import json
    var_output = []
    for var_str in expression:
        # _getframe(1) is the caller's frame; eval resolves the name there.
        # NOTE: depth-sensitive — do not wrap this loop in a helper.
        frame = sys._getframe(1)
        value = eval(var_str, frame.f_globals, frame.f_locals)
        if ' object at ' in repr(value):
            # Default object repr is unhelpful; dump the instance __dict__ instead.
            value = vars(value)
            value = json.dumps(value, indent=2)
            var_output += ['{}: {}'.format(var_str, value)]
        else:
            var_output += ['{}: {}'.format(var_str, repr(value))]
    if joiner != '\n':
        output = "[Info] {}".format(joiner.join(var_output))
    else:
        output = joiner.join(var_output)
    print(output)
    return output
def shell(cmd, working_directory='.', stdout=False, stderr=False):
    """Run `cmd` through the shell and return its captured output.

    stderr is merged into stdout (``stderr=STDOUT``), so the second element
    of the returned tuple is always empty/None.

    Parameters
    ----------
    cmd: str — shell command line.
    working_directory: str — cwd for the subprocess.
    stdout/stderr: bool — when True, echo the captured stream to the console.

    Returns
    -------
    (str | bytes | None, str | bytes | None): decoded stdout and stderr.
    """
    import subprocess

    proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, cwd=working_directory)
    out, err = proc.communicate()

    out = out.decode("utf-8") if out else out
    err = err.decode("utf-8") if err else err

    if stdout and out:
        print("[stdout]", out, "[end]")
    if stderr and err:
        print("[stderr]", err, "[end]")
    return out, err
def set_seed(seed=0):
    """Seed the random, numpy, and torch RNGs for reproducibility.

    Parameters
    ----------
    seed: int | None — the seed to use; when None, a seed is derived from
        the current time via ``efficiency.log.show_time``.
    """
    import random

    if seed is None:
        # Derive a time-based seed when none is supplied.
        from efficiency.log import show_time
        seed = int(show_time())

    print("[Info] seed set to: {}".format(seed))
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        pass  # numpy is optional
    else:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        pass  # torch is optional
    else:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
class NLP:
    """Thin spaCy wrapper providing sentence and word tokenization.

    Loads the 'en' model with ner/parser/tagger disabled and adds a
    sentencizer pipe (spaCy 2.x API).
    """

    def __init__(self):
        import spacy
        self.nlp = spacy.load('en', disable=['ner', 'parser', 'tagger'])
        self.nlp.add_pipe(self.nlp.create_pipe('sentencizer'))

    def sent_tokenize(self, text):
        """Split `text` into a list of stripped sentence strings."""
        return [sent.string.strip() for sent in self.nlp(text).sents]

    def word_tokenize(self, text, lower=False):
        """Tokenize `text`, rejoin with single spaces; optionally lowercase.

        Returns `text` unchanged when it is None.
        """
        if text is None:
            return text
        normalized = ' '.join(text.split())
        if lower:
            normalized = normalized.lower()
        return ' '.join(tok.text for tok in self.nlp.tokenizer(normalized))
def sent_bleu(ref_list, hyp):
    """Sentence-level BLEU of `hyp` against `ref_list`.

    Parameters
    ----------
    ref_list: list[str] — reference sentences (whitespace-tokenizable).
    hyp: str — hypothesis sentence.

    Returns
    -------
    float: BLEU score with NLTK's method4 smoothing.
    """
    from nltk.translate import bleu
    from nltk.translate.bleu_score import SmoothingFunction

    smoothing = SmoothingFunction().method4
    tokenized_refs = [reference.split() for reference in ref_list]
    return bleu(tokenized_refs, hyp.split(), smoothing_function=smoothing)
def init_weights(m):
    """Initialize a module's parameters (intended for ``model.apply(init_weights)``).

    Linear: Xavier-uniform weights; bias filled with 0.01 when present.
    LSTM: Xavier-uniform on weight matrices (params with >1 dims); 1-D bias
    vectors are left at their defaults. Other module types are untouched.
    """
    if isinstance(m, nn.Linear):  # isinstance also covers nn.Linear subclasses
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor
            m.bias.data.fill_(0.01)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if param.dim() > 1:  # weight matrices only; skip 1-D biases
                nn.init.xavier_uniform_(param)