-
Notifications
You must be signed in to change notification settings - Fork 4
/
config.py
132 lines (112 loc) · 4.92 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import argparse
import torch
import os
import random
import json
import numpy as np
class Config(object):
def __init__(self):
# get init config
args = self.__get_config()
for key in args.__dict__:
setattr(self, key, args.__dict__[key])
# select device
self.device = None
if self.cuda >= 0 and torch.cuda.is_available():
self.device = torch.device('cuda:{}'.format(self.cuda))
else:
self.device = torch.device('cpu')
# determine the model name and model dir
if self.model_name is None:
self.model_name = 'CNN'
self.model_dir = os.path.join(self.output_dir, self.model_name)
if not os.path.exists(self.model_dir):
os.makedirs(self.model_dir)
# backup data
self.__config_backup(args)
# set the random seed
self.__set_seed(self.seed)
def __get_config(self):
parser = argparse.ArgumentParser()
parser.description = 'config for models'
# several key selective parameters
parser.add_argument('--data_dir', type=str,
default='./data',
help='dir to load data')
parser.add_argument('--output_dir', type=str,
default='./output',
help='dir to save output')
# word embedding
parser.add_argument('--embedding_path', type=str,
default='./embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50.txt',
help='pre_trained word embedding')
parser.add_argument('--word_dim', type=int,
default=50,
help='dimension of word embedding')
# train settings
parser.add_argument('--model_name', type=str,
default=None,
help='model name')
parser.add_argument('--mode', type=int,
default=1,
choices=[0, 1],
help='running mode: 1 for training; otherwise testing')
parser.add_argument('--seed', type=int,
default=5782,
help='random seed')
parser.add_argument('--cuda', type=int,
default=0,
help='num of gpu device, if -1, select cpu')
parser.add_argument('--epoch', type=int,
default=20,
help='max epoches during training')
# hyper parameters
parser.add_argument('--dropout', type=float,
default=0.5,
help='the possiblity of dropout')
parser.add_argument('--batch_size', type=int,
default=128,
help='batch size')
parser.add_argument('--lr', type=float,
default=0.001,
help='learning rate')
parser.add_argument('--max_len', type=int,
default=100,
help='max length of sentence')
parser.add_argument('--pos_dis', type=int, default=50,
help='max distance of position embedding')
parser.add_argument('--pos_dim', type=int,
default=5,
help='dimension of position embedding')
parser.add_argument('--hidden_size', type=int, default=100,
help='the size of linear layer between convolution and classification')
# hyper parameters for cnn
parser.add_argument('--filter_num', type=int, default=200,
help='the number of filters in convolution')
parser.add_argument('--window', type=int, default=3,
help='the size of window in convolution')
parser.add_argument('--L2_decay', type=float, default=1e-5,
help='L2 weight decay')
args = parser.parse_args()
return args
def __set_seed(self, seed=1234):
os.environ['PYTHONHASHSEED'] = '{}'.format(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) # set seed for cpu
torch.cuda.manual_seed(seed) # set seed for current gpu
torch.cuda.manual_seed_all(seed) # set seed for all gpu
def __config_backup(self, args):
config_backup_path = os.path.join(self.model_dir, 'config.json')
with open(config_backup_path, 'w', encoding='utf-8') as fw:
json.dump(vars(args), fw, ensure_ascii=False)
def print_config(self):
for key in self.__dict__:
print(key, end=' = ')
print(self.__dict__[key])
if __name__ == '__main__':
config = Config()
config.print_config()