-
Notifications
You must be signed in to change notification settings - Fork 0
/
cnn_30K_ternary.py
54 lines (49 loc) · 1.75 KB
/
cnn_30K_ternary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
'''
cnn_org: 1D CNN with original data
data: 30K
objective: ternary
include mid point: True
'''
import torch
from xrd_analyzer.data.data_loader_30K import get_data_loader
from xrd_analyzer.models.cnn_classification import CNN
from xrd_analyzer.training.train import Trainer
from pathlib import Path
import json
def _make_output_dir(identifier):
    """Create (if needed) and return ``outputs/<identifier>`` next to this script."""
    path = Path(__file__).resolve().parent / "outputs" / identifier
    # parents=True: the shared "outputs" folder may not exist on a fresh
    # checkout; exist_ok=True: re-running the same experiment must not fail.
    path.mkdir(parents=True, exist_ok=True)
    return path


def _build_args(identifier, objective, save_path):
    """Return the keyword arguments for the dataloader, model and trainer.

    Args:
        identifier: run name; used as the trainer's ``model_id``.
        objective: classification objective forwarded to every component.
        save_path: directory handed to the dataloader and trainer.

    Returns:
        dict with ``'dataloader'``, ``'model'`` and ``'train'`` sections.
    """
    return {
        'dataloader': {
            'data_ratio': [0.75, 0.15, 0.10],  # train / val / test split
            'batch_size': 256,
            'objective': objective,
            'include_mid_point': True,
            'save_path': save_path,
            'random_state': 42,
        },
        'model': {
            'objective': objective,
        },
        'train': {
            'objective': objective,
            'save_model': True,
            'save_path': save_path,
            'model_id': identifier,
        },
    }


def _jsonable_args(arg_dict):
    """Return a copy of *arg_dict* with every ``save_path`` value turned into ``str``.

    ``pathlib.Path`` is not JSON-serializable. A new dict is returned so the
    original (which still holds Path objects) is left unchanged.
    """
    return {
        section: {
            key: str(value) if key == 'save_path' else value
            for key, value in params.items()
        }
        for section, params in arg_dict.items()
    }


def main(identifier='cnn_30K_ternary_true', objective='ternary', epochs=50):
    """Train the 1D CNN for *objective* and store its config under ``outputs/``.

    Args:
        identifier: run name; also the output sub-directory name.
        objective: classification objective passed to data, model and trainer.
        epochs: number of training epochs.
    """
    save_path = _make_output_dir(identifier)
    arg_dict = _build_args(identifier, objective, save_path)

    # Persist the configuration before training so that an interrupted or
    # failed run still records which settings it used.
    with open(save_path / "args_dict.json", 'w', encoding='utf-8') as f:
        json.dump(_jsonable_args(arg_dict), f, indent=2)

    # dataloader
    train_data_loader, val_data_loader, test_data_loader = get_data_loader(
        **arg_dict['dataloader'])

    # model + loss + optimizer
    model = CNN(**arg_dict['model'])
    loss = torch.nn.CrossEntropyLoss()
    # Alternative optimizer: torch.optim.Adam(model.parameters(), lr=0.001)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # training
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    trainer = Trainer(model, optimizer, loss, device, **arg_dict['train'])
    trainer.train(train_data_loader, val_data_loader,
                  test_data_loader, epochs=epochs)


if __name__ == "__main__":
    main()