-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
47 lines (36 loc) · 1.27 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import torch
from config import args
from exp.exp_main import ExpTransplit
from exp.exp_diffusion import ExpDiffusion
import random
import time
import numpy as np
# Global seeding for reproducibility: the same fixed value is fed to NumPy's
# legacy global RNG, Python's `random` module, and PyTorch's default
# generator(s), so each launch starts from identical RNG state.
fix_seed = 2024
np.random.seed(fix_seed)
random.seed(fix_seed)
torch.manual_seed(fix_seed)
def strip_extension(path):
    """Return *path* with its final file extension removed.

    Used to derive experiment/checkpoint names from ``args.data_path``.
    Unlike a ``path[:-4]`` slice, this does not assume a four-character
    extension such as ``.csv``, so ``.parquet``/``.h5`` files and
    extension-less names are not truncated. Any directory part of *path*
    is kept as-is.
    """
    import os  # local import keeps the file's top-level import block unchanged

    return os.path.splitext(path)[0]


if __name__ == "__main__":
    print('Args in experiment:')
    print(args)

    # 'Dedipeak' is driven by the diffusion experiment class; every other
    # model goes through ExpTransplit.
    Exp = ExpDiffusion if args.model in ['Dedipeak'] else ExpTransplit

    if args.training:
        # Train, then test, args.itr times. Each run is named from the model,
        # the data file's name (directory stripped) and a wall-clock timestamp.
        for _ in range(args.itr):
            timecode = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
            data_name = strip_extension(args.data_path.split("/")[-1])
            setting = f'{args.model}_{data_name}_{timecode}'
            exp = Exp(args)  # set experiments
            print(f'---- training: {setting} ----')
            exp.train(setting)
            print(f'---- testing: {setting} ----')
            exp.test(setting)
            torch.cuda.empty_cache()
    else:
        # Test-only mode: exp.test receives a checkpoint path under
        # ./checkpoints/{setting}/ (presumably loaded by the experiment class).
        # NOTE(review): this naming scheme (data path incl. any directory,
        # pred_len, loss, iteration index) differs from the training branch's
        # `{model}_{data}_{timecode}` names, so checkpoints written during
        # training are likely not found here — confirm where exp.train saves.
        for ii in range(args.itr):
            data_name = strip_extension(args.data_path)
            setting = f'{data_name}_{args.model}_{args.pred_len}_{args.loss}_{ii}'
            exp = Exp(args)  # set experiments
            print(f'---- testing: {setting} ----')
            exp.test(setting, f"./checkpoints/{setting}/checkpoint.pth")
            torch.cuda.empty_cache()