Rewrite of the load_checkpoint function #650
base: main
Conversation
Will test this on Azure soon.
Heads up: #646 will likely go in first since tests are passing there (after a loss parity check is added). There will probably be merge conflicts afterwards, but hopefully nothing too bad.
Nit: get the checkpoint saving/uploading tests to pass. They may need modification to match the new logic; just be clear about why each test change is needed.
This is a rewrite of how we determine which checkpoint to load when starting or restarting a training run. (Originally this also included a refactor of how our different checkpoint paths are processed, but I separated that out for now.)
Previously the logic for this was quite brittle, with edge cases where metaseq would not load the correct checkpoint; see #544 for an example.
The new logic first gathers checkpoints from all possible sources (restore-file, finetune-from, local checkpoints, NFS / Azure checkpoints), assigns each candidate a priority based on its training progress, and prefers local caches. It then takes the most recent checkpoint and copies it to local disk, as sketched below.
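For illustration, here is a minimal Python sketch of that selection logic. Every name in it (`CheckpointCandidate`, `parse_num_updates`, `select_and_localize`, the `checkpoint_<N>.pt` naming scheme) is a hypothetical stand-in, not metaseq's actual API:

```python
# Hypothetical sketch of the "gather, rank, localize" flow described above;
# these names are illustrative only and do not come from metaseq.
import os
import re
import shutil
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class CheckpointCandidate:
    path: str          # local path, NFS path, or Azure blob path
    num_updates: int   # training progress parsed from the checkpoint name
    is_local: bool     # local caches win ties against remote copies


def parse_num_updates(path: str) -> int:
    # Assume checkpoints are named like "checkpoint_1000.pt"; names that
    # don't encode progress (e.g. "checkpoint_last.pt") sort lowest.
    match = re.search(r"checkpoint_(\d+)", os.path.basename(path))
    return int(match.group(1)) if match else -1


def gather_candidates(sources: List[Tuple[Optional[str], bool]]) -> List[CheckpointCandidate]:
    # sources: (path, is_local) pairs covering restore-file, finetune-from,
    # the local save dir, and NFS / Azure checkpoint locations.
    return [
        CheckpointCandidate(path, parse_num_updates(path), is_local)
        for path, is_local in sources
        if path is not None
    ]


def select_and_localize(candidates: List[CheckpointCandidate],
                        local_cache_dir: str) -> Optional[str]:
    # Pick the checkpoint furthest along in training, preferring a local
    # copy on ties, then make sure the winner ends up on local disk.
    if not candidates:
        return None
    best = max(candidates, key=lambda c: (c.num_updates, c.is_local))
    if best.is_local:
        return best.path
    local_path = os.path.join(local_cache_dir, os.path.basename(best.path))
    # Stand-in for the real transfer step (NFS copy or Azure blob download).
    shutil.copyfile(best.path, local_path)
    return local_path
```

Ranking by `(num_updates, is_local)` means a remote checkpoint that is further along still beats a stale local cache, while an equal-progress tie resolves to the local copy and skips the transfer step entirely.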
To test this you need both the metaseq and metaseq-internal PRs; the internal one is here: https://github.com/fairinternal/metaseq-internal/pull/842
I tested:
What I haven't tested yet is whether starting from an Azure blob path works.