-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
71 lines (53 loc) · 1.98 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import time
import torch
import logging
from torchvision.utils import save_image
# Shared logger for the checkpoint and image helpers in this module.
logger = logging.getLogger(name='utils')
def save_checkpoint(args, epoch, losses, model, optimizer, best=False):
    """Serialize training state to a timestamped ``.pt`` file.

    The file goes into ``args.checkpoint_dir`` (created when missing) and is
    named ``{args.prefix}-epoch-{epoch}-{unix_seconds}.pt``, or
    ``{args.prefix}-best-epoch-{epoch}-{unix_seconds}.pt`` when ``best`` is
    true.

    Args:
        args: Namespace providing ``prefix`` and ``checkpoint_dir``.
        epoch: Epoch number, stored in the checkpoint and used in the filename.
        losses: Loss history, stored verbatim under ``'losses'``.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        best: Insert ``best-`` into the filename to mark this checkpoint.
    """
    state = dict(
        epoch=epoch,
        losses=losses,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    tag = 'best-epoch' if best else 'epoch'
    stamp = int(time.time())
    filename = f'{args.prefix}-{tag}-{epoch}-{stamp}.pt'

    target_dir = args.checkpoint_dir
    os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, filename)
    logger.info(f'Saving checkpoint to "{destination}"')
    torch.save(state, destination)
def _parse_checkpoint_name(name, prefix):
    """Classify a filename in the ``{prefix}-[best-]epoch-{epoch}-{timestamp}.pt`` scheme.

    Returns:
        ``(is_best, timestamp)`` when *name* matches the scheme for *prefix*,
        otherwise ``None``. Unrelated files, other prefixes that merely start
        with *prefix*, and names without an integer timestamp are rejected.
    """
    stem, ext = os.path.splitext(name)
    head = f'{prefix}-'
    if ext != '.pt' or not stem.startswith(head):
        return None
    rest = stem[len(head):]
    # Check the tag right after the prefix, so a prefix that itself contains
    # "best" does not mark every checkpoint as best.
    if rest.startswith('best-epoch-'):
        is_best, tail = True, rest[len('best-epoch-'):]
    elif rest.startswith('epoch-'):
        is_best, tail = False, rest[len('epoch-'):]
    else:
        return None
    # tail is "{epoch}-{timestamp}"; the timestamp follows the last dash.
    _, sep, stamp = tail.rpartition('-')
    if not sep:
        return None
    try:
        timestamp = int(stamp)
    except ValueError:
        return None
    return is_best, timestamp


def load_checkpoint(args, *, map_location=None):
    """Fetch and load the best checkpoint if it exists.

    Among files in ``args.checkpoint_dir`` named like
    ``{args.prefix}-best-epoch-{epoch}-{timestamp}.pt``, the one with the
    largest timestamp is loaded. If there are none, the newest
    ``{args.prefix}-epoch-{epoch}-{timestamp}.pt`` is used instead. Files
    that do not follow this scheme are ignored.

    Args:
        args: Namespace providing ``checkpoint_dir`` and ``prefix``.
        map_location: Forwarded to ``torch.load``. For example, pass ``'cpu'``
            to load a GPU-saved checkpoint on a CPU-only machine. The default
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The object returned by ``torch.load``, or ``None`` when no matching
        checkpoint exists, including when the directory itself is missing.
    """
    try:
        entries = os.listdir(args.checkpoint_dir)
    except FileNotFoundError:
        # No checkpoint directory means nothing has been saved yet.
        return None

    best, regular = [], []
    for name in entries:
        parsed = _parse_checkpoint_name(name, args.prefix)
        if parsed is None:
            continue
        is_best, timestamp = parsed
        if is_best:
            best.append((timestamp, name))
        else:
            regular.append((timestamp, name))

    candidates = best or regular
    if not candidates:
        return None
    # Newest by save time; the filename breaks ties deterministically.
    _, chosen = max(candidates)
    path = os.path.join(args.checkpoint_dir, chosen)
    logger.info(f'Loading checkpoint from "{path}"')
    # torch.load unpickles the file: only point this at trusted checkpoint dirs.
    return torch.load(path, map_location=map_location)
def save_images(args, outputs):
    """Save a batch of images.

    Each entry of ``outputs['pred_matte']`` is written with ``save_image`` to
    ``os.path.join(args.save_dir, name)``, where ``name`` is the matching
    (same-position) entry of ``outputs['name']``.

    Args:
        args: Namespace providing ``save_dir``. The directory is created when
            missing.
        outputs: Mapping with ``'name'`` (a sized sequence of filenames relative
            to ``save_dir``; these may include sub-directories) and
            ``'pred_matte'`` (indexable by position). Presumably each matte is
            an image tensor accepted by ``torchvision.utils.save_image``;
            confirm against the caller.
    """
    # exist_ok avoids the check-then-create race when several processes
    # share the same save_dir.
    os.makedirs(args.save_dir, exist_ok=True)
    names = outputs['name']
    logger.info(f'Saving {len(names)} images to {args.save_dir}')
    for idx, name in enumerate(names):
        save_path = os.path.join(args.save_dir, name)
        # A name such as "sub/img.png" needs its parent directory before writing.
        if os.path.dirname(name):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        save_image(outputs['pred_matte'][idx], save_path)