diff --git a/wenet/bin/train.py b/wenet/bin/train.py index e6a8c0f..73b697c 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -94,8 +94,11 @@ rank=args.rank) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, shuffle=True) + cv_sampler = torch.utils.data.distributed.DistributedSampler( + cv_dataset, shuffle=False) else: train_sampler = None + cv_sampler = None train_data_loader = DataLoader(train_dataset, collate_fn=collate_func, @@ -105,9 +108,10 @@ num_workers=args.num_workers) cv_data_loader = DataLoader(cv_dataset, collate_fn=cv_collate_func, + sampler=cv_sampler, shuffle=False, batch_size=1, - num_workers=0) + num_workers=args.num_workers) # Init transformer model input_dim = train_dataset.input_dim @@ -193,7 +197,19 @@ logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) executor.train(model, optimizer, scheduler, train_data_loader, device, writer, configs) - cv_loss = executor.cv(model, cv_data_loader, device, configs) + total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, configs) + if args.world_size > 1: + # all_reduce expected a sequence parameter, so we use [num_seen_utts]. + num_seen_utts = torch.Tensor([num_seen_utts]).to(device) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = torch.Tensor([total_loss]).to(device) + dist.all_reduce(total_loss) + cv_loss = total_loss[0] / num_seen_utts[0] + cv_loss = cv_loss.item() + else: + cv_loss = total_loss / num_seen_utts + logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) if args.rank == 0: save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index 376e650..a2c7617 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -7,7 +7,6 @@ class Executor: - def __init__(self): self.step = 0 @@ -84,4 +83,5 @@ def cv(self, model, data_loader, device, args): batch_idx, num_total_batch, loss.item(), loss_att.item(), loss_ctc.item(), total_loss / num_seen_utts)) - return total_loss / num_seen_utts + + return total_loss, num_seen_utts