-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_trajectory.py
66 lines (58 loc) · 3.16 KB
/
plot_trajectory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Plot the optimization path in the space spanned by principle directions.
"""
import numpy as np
import torch
import copy
import math
import h5py
import os
import argparse
import model_loader
import net_plotter
from projection import setup_PCA_directions, project_trajectory
import plot_2D
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Plot optimization trajectory')
parser.add_argument('--dataset', default='cifar10', help='dataset')
parser.add_argument('--model', default='resnet56', help='trained models')
parser.add_argument('--model_folder', default='', help='folders for models to be projected')
parser.add_argument('--dir_type', default='weights',
help="""direction type: weights (all weights except bias and BN paras) |
states (include BN.running_mean/var)""")
parser.add_argument('--ignore', default='', help='ignore bias and BN paras: biasbn (no bias or bn)')
parser.add_argument('--prefix', default='model_', help='prefix for the checkpint model')
parser.add_argument('--suffix', default='.t7', help='prefix for the checkpint model')
parser.add_argument('--start_epoch', default=0, type=int, help='min index of epochs')
parser.add_argument('--max_epoch', default=300, type=int, help='max number of epochs')
parser.add_argument('--save_epoch', default=1, type=int, help='save models every few epochs')
parser.add_argument('--dir_file', default='', help='load the direction file for projection')
args = parser.parse_args()
#--------------------------------------------------------------------------
# load the final model
#--------------------------------------------------------------------------
last_model_file = args.model_folder + '/' + args.prefix + str(args.max_epoch) + args.suffix
net = model_loader.load(args.dataset, args.model, last_model_file)
w = net_plotter.get_weights(net)
s = net.state_dict()
#--------------------------------------------------------------------------
# collect models to be projected
#--------------------------------------------------------------------------
model_files = []
for epoch in range(args.start_epoch, args.max_epoch + args.save_epoch, args.save_epoch):
model_file = args.model_folder + '/' + args.prefix + str(epoch) + args.suffix
assert os.path.exists(model_file), 'model %s does not exist' % model_file
model_files.append(model_file)
#--------------------------------------------------------------------------
# load or create projection directions
#--------------------------------------------------------------------------
if args.dir_file:
dir_file = args.dir_file
else:
dir_file = setup_PCA_directions(args, model_files, w, s)
#--------------------------------------------------------------------------
# projection trajectory to given directions
#--------------------------------------------------------------------------
proj_file = project_trajectory(dir_file, w, s, args.dataset, args.model,
model_files, args.dir_type, 'cos')
plot_2D.plot_trajectory(proj_file, dir_file)