Skip to content

Commit

Permalink
[fix] use ddp to calculate loss of dev set (ZhengkunTian#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
whiteshirt0429 authored Dec 18, 2020
1 parent 107b478 commit 149889f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
20 changes: 18 additions & 2 deletions wenet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@
rank=args.rank)
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, shuffle=True)
cv_sampler = torch.utils.data.distributed.DistributedSampler(
cv_dataset, shuffle=False)
else:
train_sampler = None
cv_sampler = None

train_data_loader = DataLoader(train_dataset,
collate_fn=collate_func,
Expand All @@ -105,9 +108,10 @@
num_workers=args.num_workers)
cv_data_loader = DataLoader(cv_dataset,
collate_fn=cv_collate_func,
sampler=cv_sampler,
shuffle=False,
batch_size=1,
num_workers=0)
num_workers=args.num_workers)

# Init transformer model
input_dim = train_dataset.input_dim
Expand Down Expand Up @@ -193,7 +197,19 @@
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, scheduler, train_data_loader, device,
writer, configs)
cv_loss = executor.cv(model, cv_data_loader, device, configs)
total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, configs)
if args.world_size > 1:
# all_reduce operates on tensors, so wrap the scalar in a list to build a one-element tensor.
num_seen_utts = torch.Tensor([num_seen_utts]).to(device)
# The default reduction op of all_reduce is SUM, so every rank ends up with the global total.
dist.all_reduce(num_seen_utts)
total_loss = torch.Tensor([total_loss]).to(device)
dist.all_reduce(total_loss)
cv_loss = total_loss[0] / num_seen_utts[0]
cv_loss = cv_loss.item()
else:
cv_loss = total_loss / num_seen_utts

logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
Expand Down
4 changes: 2 additions & 2 deletions wenet/utils/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class Executor:

def __init__(self):
self.step = 0

Expand Down Expand Up @@ -84,4 +83,5 @@ def cv(self, model, data_loader, device, args):
batch_idx, num_total_batch, loss.item(),
loss_att.item(), loss_ctc.item(),
total_loss / num_seen_utts))
return total_loss / num_seen_utts

return total_loss, num_seen_utts

0 comments on commit 149889f

Please sign in to comment.