diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index 906ce72f0..03131b75c 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -191,10 +191,7 @@ def prepare_deepspeed(accelerator, model): return model -def main(): - parser = ArgumentParserPlus((FlatArguments)) - args = parser.parse() - +def main(args: FlatArguments): # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment @@ -657,4 +654,6 @@ def load_model(): if __name__ == "__main__": - main() + parser = ArgumentParserPlus((FlatArguments)) + args = parser.parse() + main(args) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 77ebc6f7c..39c49ffb3 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -171,10 +171,7 @@ def save_with_accelerate(accelerator, model, tokenizer, output_dir, args): ) -def main(): - parser = ArgumentParserPlus((FlatArguments)) - args = parser.parse() - +def main(args: FlatArguments): # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment @@ -690,4 +687,6 @@ def main(): if __name__ == "__main__": - main() + parser = ArgumentParserPlus((FlatArguments)) + args = parser.parse() + main(args)