diff --git a/psiflow/models/mace_utils.py b/psiflow/models/mace_utils.py index 39a106e..7db5515 100644 --- a/psiflow/models/mace_utils.py +++ b/psiflow/models/mace_utils.py @@ -690,10 +690,12 @@ def main(): type=str, ) args = parser.parse_args() - if args.distributed: - world_size = torch.cuda.device_count() + world_size = torch.cuda.device_count() + if world_size > 1: + args.distributed = True import torch.multiprocessing as mp mp.spawn(run, args=(args, world_size), nprocs=world_size) else: + args.distributed = False run(0, args, 1)