diff --git a/train.py b/train.py
index 7447424..bd4e0f9 100644
--- a/train.py
+++ b/train.py
@@ -44,20 +44,20 @@ def print(*args, **kwargs):
 def init_distributed_mode(args):
     args.distributed = False
-    if args.device == 'cuda' and torch.cuda.device_count() > 1:
+    if args.device == 'cuda' and 'WORLD_SIZE' in os.environ:
         args.distributed = True
-        args.world_size = torch.cuda.device_count()
-        args.rank = os.environ['RANK']
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.rank = int(os.environ['RANK'])
         torch.cuda.set_device(args.local_rank)
         # args.local_rank, os.environ["RANK"], os.environ['WORLD_SIZE'] are assigned automatically
-        # print("args.local_rank:{},RANK:{},WORLD_SIZE:{}".format(args.local_rank, os.environ["RANK"],
-        #                                                         os.environ['WORLD_SIZE']))
+        print("args.local_rank:{},RANK:{},WORLD_SIZE:{}".format(args.local_rank, os.environ["RANK"],
+                                                                 os.environ['WORLD_SIZE']))
         print('| distributed init (rank {}): {}'.format(
             args.rank, args.dist_url), flush=True)
         torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                              world_size=args.world_size, rank=args.rank)
-        setup_for_distributed(args.local_rank == 0)
+        setup_for_distributed(args.rank == 0)
 
 
 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
@@ -112,7 +112,9 @@ def train(args):
     model_without_ddp = model
 
     if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model = torch.nn.parallel.DistributedDataParallel(model,
+                                                          device_ids=[args.local_rank],
+                                                          output_device=args.local_rank)
         model_without_ddp = model.module
     # Load the pretrained model
     if args.init_epoch > 0:
@@ -158,7 +160,8 @@ def train(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--device", type=str, default='cpu', help="cpu or cuda")
-    parser.add_argument("--direction", type=str, default='horizontal', help="horizontal or vertical")
+    parser.add_argument("--direction", type=str, choices=['horizontal', 'vertical'],
+                        default='horizontal', help="horizontal or vertical")
    parser.add_argument("--batch-size", type=int, default=64, help="batch size")
    parser.add_argument("--epochs", type=int, default=90, help="epochs")
    parser.add_argument("--init-epoch", type=int, default=0, help="init epoch")
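
For reference, the rewritten init path assumes the script is started by a launcher (e.g. torchrun or python -m torch.distributed.launch) that exports WORLD_SIZE, RANK, and LOCAL_RANK for each worker process, which is why the check on torch.cuda.device_count() is replaced by a check for 'WORLD_SIZE' in the environment. Below is a minimal, self-contained sketch of that pattern; the nccl backend, the env:// init method, and the standalone init_distributed helper name are illustrative assumptions, not the repository's exact code.

```python
import os

import torch
import torch.distributed as dist


def init_distributed():
    """Sketch of env-var-driven DDP setup (assumes a torchrun-style launcher).

    torchrun exports WORLD_SIZE, RANK, and LOCAL_RANK for every worker,
    so their presence is the signal that this is a distributed run.
    """
    if 'WORLD_SIZE' not in os.environ:
        return False  # single-process run: skip distributed setup

    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    torch.cuda.set_device(local_rank)          # pin this process to one GPU
    dist.init_process_group(backend='nccl',    # assumption: NCCL for CUDA
                            init_method='env://',
                            world_size=world_size,
                            rank=rank)
    return True


# Example launch (assumption: 4 GPUs on one node):
#   torchrun --nproc_per_node=4 train.py --device cuda
```

Gating setup_for_distributed on args.rank == 0 (the global rank) rather than args.local_rank == 0 means only one process keeps printing in multi-node runs, since every node has a local rank 0.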