diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index b6c877411..a2b5d4e29 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -577,7 +577,7 @@ def main(args: FlatArguments): args.dataset_mixer, configs=args.dataset_config_name, splits=["train"], - save_data_dir=args.dataset_mix_dir, + save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, columns_to_keep=["chosen", "rejected"], ) elif args.dataset_mixer_list is not None: @@ -586,7 +586,7 @@ def main(args: FlatArguments): args.dataset_mixer_list, configs=args.dataset_config_name, splits=["train"], - save_data_dir=args.dataset_mix_dir, + save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, columns_to_keep=["chosen", "rejected"], ) else: diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 970f24921..5d3ca8fe2 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -526,7 +526,7 @@ def main(args: FlatArguments): args.dataset_mixer, configs=args.dataset_config_name, splits=["train"], - save_data_dir=args.dataset_mix_dir, + save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, columns_to_keep=["messages"], ) elif args.dataset_mixer_list is not None: @@ -535,7 +535,7 @@ def main(args: FlatArguments): args.dataset_mixer_list, configs=args.dataset_config_name, splits=["train"], - save_data_dir=args.dataset_mix_dir, + save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, columns_to_keep=["messages"], ) else: