-
Notifications
You must be signed in to change notification settings - Fork 54
/
main_task_caption.py
689 lines (562 loc) · 32.8 KB
/
main_task_caption.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import torch
from torch.utils.data import (SequentialSampler)
import numpy as np
import random
import os
from collections import OrderedDict
from nlgeval import NLGEval
import time
import argparse
from modules.tokenization import BertTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import UniVL
from modules.optimization import BertAdam
from modules.beam import Beam
from torch.utils.data import DataLoader
from dataloaders.dataloader_youcook_caption import Youcook_Caption_DataLoader
from dataloaders.dataloader_msrvtt_caption import MSRVTT_Caption_DataLoader
from util import get_logger
torch.distributed.init_process_group(backend="nccl")
global logger
def get_args(description='UniVL on Caption Task'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument('--train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--data_path', type=str, default='data/youcookii_caption_transcript.pickle',
help='caption and transcription pickle file path')
parser.add_argument('--features_path', type=str, default='data/youcookii_videos_feature.pickle',
help='feature path for 2D features')
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--max_words', type=int, default=20, help='')
parser.add_argument('--max_frames', type=int, default=100, help='')
parser.add_argument('--feature_framerate', type=int, default=1, help='')
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True, help="Bert pre-trained model")
parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--task_type", default="caption", type=str, help="Point the task `caption` to finetune.")
parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")
parser.add_argument("--world_size", default=0, type=int, help="distribted training")
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether use MIL, has a high priority than use_mil.")
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.")
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.")
parser.add_argument('--stage_two', action='store_true', help="Whether training with decoder.")
args = parser.parse_args()
# Check paramenters
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
return args
def set_seed_logger(args):
global logger
# predefining random initial seeds
random.seed(args.seed)
os.environ['PYTHONHASHSEED'] = str(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(args.local_rank)
args.world_size = world_size
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
if args.local_rank == 0:
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
return args
def init_device(args, local_rank):
global logger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
n_gpu = torch.cuda.device_count()
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
args.n_gpu = n_gpu
if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))
return device, n_gpu
def init_model(args, device, n_gpu, local_rank):
if args.init_model:
model_state_dict = torch.load(args.init_model, map_location='cpu')
else:
model_state_dict = None
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
return model
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
if hasattr(model, 'module'):
model = model.module
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]
no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]
decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]
optimizer_grouped_parameters = [
{'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
{'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
{'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
{'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
]
scheduler = None
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
max_grad_norm=1.0)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
return optimizer, scheduler, model
def dataloader_youcook_train(args, tokenizer):
youcook_dataset = Youcook_Caption_DataLoader(
csv=args.train_csv,
data_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
max_frames=args.max_frames,
)
train_sampler = torch.utils.data.distributed.DistributedSampler(youcook_dataset)
dataloader = DataLoader(
youcook_dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(train_sampler is None),
sampler=train_sampler,
drop_last=True,
)
return dataloader, len(youcook_dataset), train_sampler
def dataloader_youcook_test(args, tokenizer):
youcook_testset = Youcook_Caption_DataLoader(
csv=args.val_csv,
data_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
max_frames=args.max_frames,
)
test_sampler = SequentialSampler(youcook_testset)
dataloader_youcook = DataLoader(
youcook_testset,
sampler=test_sampler,
batch_size=args.batch_size_val,
num_workers=args.num_thread_reader,
pin_memory=False,
)
if args.local_rank == 0:
logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
return dataloader_youcook, len(youcook_testset)
def dataloader_msrvtt_train(args, tokenizer):
msrvtt_dataset = MSRVTT_Caption_DataLoader(
csv_path=args.train_csv,
json_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
max_frames=args.max_frames,
split_type="train",
)
train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset)
dataloader = DataLoader(
msrvtt_dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(train_sampler is None),
sampler=train_sampler,
drop_last=True,
)
return dataloader, len(msrvtt_dataset), train_sampler
def dataloader_msrvtt_test(args, tokenizer, split_type="test",):
msrvtt_testset = MSRVTT_Caption_DataLoader(
csv_path=args.val_csv,
json_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
max_frames=args.max_frames,
split_type=split_type,
)
test_sampler = SequentialSampler(msrvtt_testset)
dataloader_msrvtt = DataLoader(
msrvtt_testset,
sampler=test_sampler,
batch_size=args.batch_size_val,
num_workers=args.num_thread_reader,
pin_memory=False,
drop_last=False,
)
return dataloader_msrvtt, len(msrvtt_testset)
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_model(epoch, args, model, type_name=""):
# Only save the model it-self
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(
args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
torch.save(model_to_save.state_dict(), output_model_file)
logger.info("Model saved to %s", output_model_file)
return output_model_file
def load_model(epoch, args, n_gpu, device, model_file=None):
if model_file is None or len(model_file) == 0:
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
if os.path.exists(model_file):
model_state_dict = torch.load(model_file, map_location='cpu')
if args.local_rank == 0:
logger.info("Model loaded from %s", model_file)
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
else:
model = None
return model
def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer, scheduler,
global_step, nlgEvalObj=None, local_rank=0):
global logger
torch.cuda.empty_cache()
model.train()
log_step = args.n_display
start_time = time.time()
total_loss = 0
for step, batch in enumerate(train_dataloader):
# if n_gpu == 1:
# # multi-gpu does scattering it-self
# batch = tuple(t.to(device) for t in batch)
batch = tuple(t.to(device=device, non_blocking=True) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index,\
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch
loss = model(input_ids, segment_ids, input_mask, video, video_mask,
pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
masked_video=masked_video, video_labels_index=video_labels_index,
input_caption_ids=pairs_input_caption_ids, decoder_mask=pairs_decoder_mask,
output_caption_ids=pairs_output_caption_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
total_loss += float(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if scheduler is not None:
scheduler.step() # Update learning rate schedule
optimizer.step()
optimizer.zero_grad()
global_step += 1
if global_step % log_step == 0 and local_rank == 0:
logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
args.epochs, step + 1,
len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
float(loss),
(time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
start_time = time.time()
total_loss = total_loss / len(train_dataloader)
return total_loss, global_step
# ---------------------------------------->
def get_inst_idx_to_tensor_position_map(inst_idx_list):
''' Indicate the position of an instance in a tensor. '''
return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
''' Collect tensor parts associated to active instances. '''
_, *d_hs = beamed_tensor.size()
n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, *d_hs)
beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
beamed_tensor = beamed_tensor.view(*new_shape)
return beamed_tensor
def collate_active_info(input_tuples, inst_idx_to_position_map, active_inst_idx_list, n_bm, device):
assert isinstance(input_tuples, tuple)
sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt = input_tuples
# Sentences which are still active are collected,
# so the decoder will not run on completed sentences.
n_prev_active_inst = len(inst_idx_to_position_map)
active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
active_inst_idx = torch.LongTensor(active_inst_idx).to(device)
active_sequence_output_rpt = collect_active_part(sequence_output_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_visual_output_rpt = collect_active_part(visual_output_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_input_ids_rpt = collect_active_part(input_ids_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_input_mask_rpt = collect_active_part(input_mask_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_video_mask_rpt = collect_active_part(video_mask_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
return (active_sequence_output_rpt, active_visual_output_rpt, active_input_ids_rpt, active_input_mask_rpt, active_video_mask_rpt), \
active_inst_idx_to_position_map
def beam_decode_step(decoder, inst_dec_beams, len_dec_seq,
inst_idx_to_position_map, n_bm, device, input_tuples, decoder_length=None):
assert isinstance(input_tuples, tuple)
''' Decode and update beam status, and then return active beam idx'''
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
dec_partial_seq = torch.stack(dec_partial_seq).to(device)
dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
return dec_partial_seq
def predict_word(next_decoder_ids, n_active_inst, n_bm, device, input_tuples):
sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt = input_tuples
next_decoder_mask = torch.ones(next_decoder_ids.size(), dtype=torch.uint8).to(device)
dec_output = decoder(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt,
video_mask_rpt, next_decoder_ids, next_decoder_mask, shaped=True, get_logits=True)
dec_output = dec_output[:, -1, :]
word_prob = torch.nn.functional.log_softmax(dec_output, dim=1)
word_prob = word_prob.view(n_active_inst, n_bm, -1)
return word_prob
def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map, decoder_length=None):
active_inst_idx_list = []
for inst_idx, inst_position in inst_idx_to_position_map.items():
if decoder_length is None:
is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
else:
is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position], word_length=decoder_length[inst_idx])
if not is_inst_complete:
active_inst_idx_list += [inst_idx]
return active_inst_idx_list
n_active_inst = len(inst_idx_to_position_map)
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
word_prob = predict_word(dec_seq, n_active_inst, n_bm, device, input_tuples)
# Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list(inst_dec_beams, word_prob, inst_idx_to_position_map,
decoder_length=decoder_length)
return active_inst_idx_list
def collect_hypothesis_and_scores(inst_dec_beams, n_best):
all_hyp, all_scores = [], []
for inst_idx in range(len(inst_dec_beams)):
scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
all_scores += [scores[:n_best]]
hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
all_hyp += [hyps]
return all_hyp, all_scores
# >----------------------------------------
def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=None, test_set=None):
if hasattr(model, 'module'):
model = model.module.to(device)
if model._stage_one:
return 0.
all_result_lists = []
all_caption_lists = []
model.eval()
for batch in test_dataloader:
batch = tuple(t.to(device, non_blocking=True) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch
with torch.no_grad():
sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
# -- Repeat data for beam search
n_bm = 5 # beam_size
device = sequence_output.device
n_inst, len_s, d_h = sequence_output.size()
_, len_v, v_h = visual_output.size()
decoder = model.decoder_caption
# Note: shaped first, then decoder need the parameter shaped=True
input_ids = input_ids.view(-1, input_ids.shape[-1])
input_mask = input_mask.view(-1, input_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_v, v_h)
input_ids_rpt = input_ids.repeat(1, n_bm).view(n_inst * n_bm, len_s)
input_mask_rpt = input_mask.repeat(1, n_bm).view(n_inst * n_bm, len_s)
video_mask_rpt = video_mask.repeat(1, n_bm).view(n_inst * n_bm, len_v)
# -- Prepare beams
inst_dec_beams = [Beam(n_bm, device=device, tokenizer=tokenizer) for _ in range(n_inst)]
# -- Bookkeeping for active or not
active_inst_idx_list = list(range(n_inst))
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
# -- Decode
for len_dec_seq in range(1, args.max_words + 1):
active_inst_idx_list = beam_decode_step(decoder, inst_dec_beams,
len_dec_seq, inst_idx_to_position_map, n_bm, device,
(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt))
if not active_inst_idx_list:
break # all instances have finished their path to <EOS>
(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt), \
inst_idx_to_position_map = collate_active_info((sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt),
inst_idx_to_position_map, active_inst_idx_list, n_bm, device)
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
result_list = [batch_hyp[i][0] for i in range(n_inst)]
pairs_output_caption_ids = pairs_output_caption_ids.view(-1, pairs_output_caption_ids.shape[-1])
caption_list = pairs_output_caption_ids.cpu().detach().numpy()
for re_idx, re_list in enumerate(result_list):
decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
if "[SEP]" in decode_text_list:
SEP_index = decode_text_list.index("[SEP]")
decode_text_list = decode_text_list[:SEP_index]
if "[PAD]" in decode_text_list:
PAD_index = decode_text_list.index("[PAD]")
decode_text_list = decode_text_list[:PAD_index]
decode_text = ' '.join(decode_text_list)
decode_text = decode_text.replace(" ##", "").strip("##").strip()
all_result_lists.append(decode_text)
for re_idx, re_list in enumerate(caption_list):
decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
if "[SEP]" in decode_text_list:
SEP_index = decode_text_list.index("[SEP]")
decode_text_list = decode_text_list[:SEP_index]
if "[PAD]" in decode_text_list:
PAD_index = decode_text_list.index("[PAD]")
decode_text_list = decode_text_list[:PAD_index]
decode_text = ' '.join(decode_text_list)
decode_text = decode_text.replace(" ##", "").strip("##").strip()
all_caption_lists.append(decode_text)
# Save full results
if test_set is not None and hasattr(test_set, 'iter2video_pairs_dict'):
hyp_path = os.path.join(args.output_dir, "hyp_complete_results.txt")
with open(hyp_path, "w", encoding='utf-8') as writer:
writer.write("{}\t{}\t{}\n".format("video_id", "start_time", "caption"))
for idx, pre_txt in enumerate(all_result_lists):
video_id, sub_id = test_set.iter2video_pairs_dict[idx]
start_time = test_set.data_dict[video_id]['start'][sub_id]
writer.write("{}\t{}\t{}\n".format(video_id, start_time, pre_txt))
logger.info("File of complete results is saved in {}".format(hyp_path))
# Save pure results
hyp_path = os.path.join(args.output_dir, "hyp.txt")
with open(hyp_path, "w", encoding='utf-8') as writer:
for pre_txt in all_result_lists:
writer.write(pre_txt+"\n")
ref_path = os.path.join(args.output_dir, "ref.txt")
with open(ref_path, "w", encoding='utf-8') as writer:
for ground_txt in all_caption_lists:
writer.write(ground_txt + "\n")
if args.datatype == "msrvtt":
all_caption_lists = []
sentences_dict = test_dataloader.dataset.sentences_dict
video_sentences_dict = test_dataloader.dataset.video_sentences_dict
for idx in range(len(sentences_dict)):
video_id, _ = sentences_dict[idx]
sentences = video_sentences_dict[video_id]
all_caption_lists.append(sentences)
all_caption_lists = [list(itms) for itms in zip(*all_caption_lists)]
else:
all_caption_lists = [all_caption_lists]
# Evaluate
metrics_nlg = nlgEvalObj.compute_metrics(ref_list=all_caption_lists, hyp_list=all_result_lists)
logger.info(">>> BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"], metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
logger.info(">>> METEOR: {:.4f}, ROUGE_L: {:.4f}, CIDEr: {:.4f}".format(metrics_nlg["METEOR"], metrics_nlg["ROUGE_L"], metrics_nlg["CIDEr"]))
Bleu_4 = metrics_nlg["Bleu_4"]
return Bleu_4
DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader_youcook_test}
DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test}
def main():
global logger
args = get_args()
args = set_seed_logger(args)
device, n_gpu = init_device(args, args.local_rank)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = init_model(args, device, n_gpu, args.local_rank)
assert args.task_type == "caption"
nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)
assert args.datatype in DATALOADER_DICT
test_dataloader, test_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer)
if args.local_rank == 0:
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_train:
train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)
if args.local_rank == 0:
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)
best_score = 0.00001
best_output_model_file = None
global_step = 0
for epoch in range(args.epochs):
train_sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=args.local_rank)
if args.local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
output_model_file = save_model(epoch, args, model, type_name="")
if epoch > 0:
Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
if best_score <= Bleu_4:
best_score = Bleu_4
best_output_model_file = output_model_file
logger.info("The best model is: {}, the Bleu_4 is: {:.4f}".format(best_output_model_file, best_score))
else:
logger.warning("Skip the evaluation after {}-th epoch.".format(epoch+1))
if args.local_rank == 0:
model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
elif args.do_eval:
if args.local_rank == 0:
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
if __name__ == "__main__":
main()