Skip to content

Commit

Permalink
修复bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yizt committed Jan 11, 2020
1 parent 07e461b commit 3744412
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ def print(*args, **kwargs):

def init_distributed_mode(args):
    """Set the distributed-training fields on ``args`` and join the process group.

    ``args.distributed`` is always set. Distributed mode is enabled only when
    ``args.device`` is ``'cuda'`` and the ``WORLD_SIZE`` environment variable
    is present; in that case ``args.world_size`` and ``args.rank`` are read
    from the environment, the CUDA device is selected from
    ``args.local_rank``, and ``torch.distributed`` is initialised.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``local_rank``,
            ``dist_url`` and ``dist_backend``; writes ``distributed``,
            ``world_size`` and ``rank``.

    Raises:
        KeyError: If ``WORLD_SIZE`` is set but ``RANK`` is not.
    """
    args.distributed = False
    # Key off the environment rather than the visible GPU count, so a plain
    # single-process run on a multi-GPU machine stays non-distributed.
    if args.device == 'cuda' and 'WORLD_SIZE' in os.environ:
        args.distributed = True
        # Environment values are strings; cast them so world_size and rank
        # are passed on as integers.
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.rank = int(os.environ['RANK'])
        torch.cuda.set_device(args.local_rank)

        # args.local_rank, os.environ["RANK"] and os.environ['WORLD_SIZE'] are
        # assigned automatically (presumably by the distributed launcher --
        # TODO confirm).
        print("args.local_rank:{},RANK:{},WORLD_SIZE:{}".format(args.local_rank, os.environ["RANK"],
                                                                os.environ['WORLD_SIZE']))
        print('| distributed init (rank {}): {}'.format(
            args.rank, args.dist_url), flush=True)
        torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.rank)
        # Use the global rank, not the per-node local rank, so exactly one
        # process is flagged as the master across all nodes.
        setup_for_distributed(args.rank == 0)


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
Expand Down Expand Up @@ -112,7 +112,9 @@ def train(args):

model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank],
output_device=args.local_rank)
model_without_ddp = model.module
# 加载预训练模型
if args.init_epoch > 0:
Expand Down Expand Up @@ -158,7 +160,8 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default='cpu', help="cpu or cuda")
parser.add_argument("--direction", type=str, default='horizontal', help="horizontal or vertical")
parser.add_argument("--direction", type=str, choices=['horizontal', 'vertical'],
default='horizontal', help="horizontal or vertical")
parser.add_argument("--batch-size", type=int, default=64, help="batch size")
parser.add_argument("--epochs", type=int, default=90, help="epochs")
parser.add_argument("--init-epoch", type=int, default=0, help="init epoch")
Expand Down

0 comments on commit 3744412

Please sign in to comment.