Add script to convert pickled Llama weights to DCP #634
base: main
Conversation
…pickled weights to torch DCP
Hi @rlrs! Thank you for your pull request and welcome to our community.

Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
@rlrs Looking at the structure of the original weights, this script only covers the 8B model.
Ah right, thanks for noticing. Will update that in a bit.
Here we go, the latest commit hopefully works for both 70B and 405B, if I've understood it correctly. You do have to load up all the shards at once with this approach; I'm not sure there's an easy way around it. Currently testing to make sure it works with 70B, and it takes longer to test than to write the code... Unfortunately I only have a node with 512 GB RAM, so I can't convert or test 405B myself. I do suspect that someone with more knowledge about DCP could potentially process a shard at a time, or even produce multiple shards for more efficient loading. This seems difficult to do without knowing the parallelization scheme you plan to use, though.

Edit: Verified that 70B works. Unsure what a good test in the repo would be, if one is necessary. I loaded up the checkpoint and ensured that loss on a small dataset is low.
@rlrs Thanks for the update. Once the TorchTitan team has verified this, I will help by converting all the Llama models and uploading them to Huggingface in DCP format (if the team does not do this).
The code looks good in general.
I do suspect that someone with more knowledge about DCP could potentially process a shard at a time, or even produce multiple shards for more efficient loading.
Discussed with @wz337 offline, and it seems it's not easy to have a fundamentally more memory-efficient implementation, unless one happens to have at least as many GPUs as the weights were originally saved from.
Verified that 70B works - unsure what a good test would be in the repo, if necessary. I loaded up the checkpoint and ensured that loss on a small dataset is low.
I agree it's not clear what a unit test would look like. In that case, please just add a brief tutorial in docs/checkpoint.md.
Thanks for the script. Some general comments:
What's our target model? 8B? 70B? 405B? The tool is likely to be bound by time and storage, not memory.
For 70B, I think a couple of hours is a reasonable estimate: (70 * 4 * 1000 / 100) seconds per single load (torch.load) or save (DCP.save, torch.save). So 1-2 hours should be reasonable.
For 405B, if you use mmap=True, it is possible to convert 405B in 2 passes: the first pass converts the state_dict and the second pass saves with DCP. However, with 405B, we may end up with a super large file that the file system may not support. And this design means 4x save and load, (405 * 4 * 1000 / 100) * 4, roughly 18 hours if nothing goes wrong, and the tool must be carefully designed.
For 8B and 70B models, it is reasonable to use this tool. For 405B, using the TorchTitan trainer to do the conversion may be a more reasonable approach.
Some thoughts on the unit test (open to discussion on whether this is necessary): maybe we can use the smallest Llama weights (Llama 3.2 has a 1B model). We can:
As for the documentation, feel free to make
Thanks @fegin !
1GB = 1000MB, and 100 means 100MB/s for disk read/write; that assumption is neither particularly slow nor particularly fast, depending on the storage.
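For concreteness, the same back-of-the-envelope arithmetic as a tiny Python sketch (the fp32 size and 100 MB/s throughput are assumptions from the comment above, not measurements):

# Rough I/O time for a single torch.load or DCP.save pass over 70B weights.
params_billion = 70
bytes_per_param = 4          # assumed fp32
throughput_mb_per_s = 100    # assumed sustained disk read/write speed

size_mb = params_billion * bytes_per_param * 1000      # 280,000 MB = 280 GB
seconds_per_pass = size_mb / throughput_mb_per_s       # 2,800 s, about 47 minutes
print(f"~{seconds_per_pass / 60:.0f} min per load or save pass")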
I meant changing TorchTitan's CheckpointManager to accept some customized function for the first iteration, to load a customized checkpoint.
Addressed all the comments on the code now, I think. Will add some information to checkpoint.md in a bit.
I think it is a better approach. But this requires some design. We can land this script first.
Why 16GB? Isn't the dtype f32? Can we also try the mmap option as mentioned in the comment?
Thanks! Looks almost ready to merge.
- Can you please try @fegin's suggestion of using mmap?
Why 16GB? Isn't the dtype f32? Can we also try mmap option as mentioned in the comment?
I'm a bit confused. If 8B in float32 occupies 32GB, wouldn't this approach need 2x, i.e. 64GB, since it loads all the shards and then creates a new state_dict, each being a copy of the model?
- Also please add a short tutorial in docs/checkpoint.md.
Below are my attempts at converting:

70B ✅

190k  .metadata
18G   __0_0.distcp
18G   __0_1.distcp
18G   __0_2.distcp
18G   __0_3.distcp
18G   __0_4.distcp
18G   __0_5.distcp
18G   __0_6.distcp
18G   __0_7.distcp

405B ❌
shards[i][f"layers.{layer}.attention.{wn}.weight"].view(
RuntimeError: shape '[0, 128, 16384]' is invalid for input of size 16777216

The issue is when you have nested shards (as in the 405B download layout). I made changes to this script to consolidate each nested folder. After processing 105 of 126 layers, I got OOM (>1.5TB). I worked around it by deallocating memory within the loop after processing each shard (peaked at ~840GB); a rough sketch of that pattern follows the listing. ✅

300k  .metadata
103G  __0_0.distcp
103G  __0_1.distcp
103G  __0_2.distcp
103G  __0_3.distcp
103G  __0_4.distcp
103G  __0_5.distcp
102G  __0_6.distcp
102G  __0_7.distcp
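For reference, a minimal sketch of the deallocation workaround described above. This is not the actual patch; merge_one is a hypothetical stand-in for the script's per-tensor concatenation logic, and shards is assumed to be the list loaded earlier in the script.

# Merge one tensor at a time and immediately drop the per-shard copies, so memory
# holds roughly one merged copy plus the tensor being processed, instead of all
# shards plus the full merged state_dict at once.
state_dict = {}
for key in list(shards[0].keys()):
    state_dict[key] = merge_one(key, [shard[key] for shard in shards])  # hypothetical helper
    for shard in shards:
        del shard[key]  # release this shard's tensor as soon as it has been merged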
@jaysonfrancis Could you add your changes in a comment or as a PR to @rlrs's branch?
The llama3 original weights are published in bf16, it seems. E.g. the pickle at https://huggingface.co/meta-llama/Llama-3.1-8B/tree/main/original has BFloat16Storage.
@jaysonfrancis I'd appreciate your fixes for loading 405B from HF; however, I am unsure if we should support it. 405B on HF is not stored in the original format, which is why my code does not work for it. See here for an explanation: https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/original/mp16/README.md
Looks great to me! Thank you very much for contributing!
I wonder what's the peak memory usage now if we load the 8B model, with the help of mmap and del.
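In case it helps with measuring that, one quick way (Unix only) to read the process's peak RSS after a conversion run; convert(...) below is just a placeholder for however the script's entry point is invoked.

# Print peak resident set size after running the conversion in-process.
# Note: ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
import resource
import sys

# convert(input_dir, output_dir)   # placeholder for the script's entry point

peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
unit = "bytes" if sys.platform == "darwin" else "KB"
print(f"peak RSS: {peak} {unit}")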
Thanks @jaysonfrancis !
I agree that we probably shouldn't land the adaptation to the special treatment from HF. It is OK to keep a branch or even a PR on it.
    torch.load(ckpt, map_location="cpu", weights_only=True, mmap=True)
    for ckpt in checkpoint_list
]
# Load shards
subdirs = [folder for folder in input_dir.iterdir() if folder.is_dir()]
if subdirs:
    checkpoint_folders = sorted(subdirs)
    logger.info(f"Loading original Llama weights from {len(checkpoint_folders)} folders")
    # Load all .pth files within each folder and treat each folder as a shard
    shards = []
    for folder in checkpoint_folders:
        shard = {}
        for pth_file in sorted(folder.glob("*.pth")):
            shard.update(torch.load(pth_file, map_location="cpu"))
        shards.append(shard)
else:
    checkpoint_list = sorted([file for file in input_dir.rglob("*.pth")])
    logger.info(f"Loading original Llama weights from {len(checkpoint_list)} files")
    shards = [
        torch.load(ckpt, map_location="cpu")
        for ckpt in checkpoint_list
    ]
if len(shards) == 1:
    state_dict = shards[0]
else:  # sharded
    ],
    dim=0,
).reshape(nh * len(shards) * dims_per_head, dim)
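For anyone reading that fragment out of context, here is a toy version of the same shape manipulation (the sizes are made up, not real Llama dimensions): each shard's projection is viewed as (local_heads, dims_per_head, dim), the shards are stacked along the head axis, and the result is flattened back into a 2-D weight.

import torch

# Toy sizes, not real model dims.
nh, dims_per_head, dim, n_shards = 2, 4, 8, 2
shards = [torch.randn(nh * dims_per_head, dim) for _ in range(n_shards)]

merged = torch.cat(
    [s.view(nh, dims_per_head, dim) for s in shards],
    dim=0,
).reshape(nh * n_shards * dims_per_head, dim)

assert merged.shape == (nh * n_shards * dims_per_head, dim)  # (16, 8) here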
Sounds good, it was mentioned earlier in the thread so I figured it was needed. Thanks!
I believe the peak should be one full weight copy + the single biggest matrix. So only slightly more than one full copy, assuming the DCP save call doesn't allocate much.
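Putting rough numbers on that for 8B in bf16 (the vocab and hidden sizes below are the published Llama 3 8B config, used here only as an illustration, not something from this PR):

# One full bf16 copy of the 8B weights plus the single largest matrix
# (the 128256 x 4096 token embedding / output projection).
params = 8e9
bytes_per_param = 2                                        # bf16
full_copy_gb = params * bytes_per_param / 1e9              # ~16 GB
vocab, dim = 128_256, 4_096
biggest_matrix_gb = vocab * dim * bytes_per_param / 1e9    # ~1.05 GB
print(f"~{full_copy_gb + biggest_matrix_gb:.1f} GB peak, if DCP.save adds little")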
The reason it's nice to have it on HF is that the download speed is much faster. I just tried to start a download of the 405B from Meta and it runs at 2MB/s, whereas HF can give you speeds of several GB/s. Are there any alternatives for downloading the 405B weights to test this PR?
You can run the code HF provides for undoing the additional sharding they've performed; see my link above.
I converted the full Llama 3.1 series to DCP by using your script. The 405B weights took about one hour from torch to DCP after I converted them from the sharded format back to torch.
Can you explain how to run with the weights after conversion to DCP, @rlrs? I put the DCP weights in a folder, and the error triggers here:

torchtitan/torchtitan/checkpoint.py, lines 483 to 486 in 1060fea

Error:
Oh no, I probably overlooked it because my code doesn't load freqs_cis:

torchtitan/torchtitan/models/llama/model.py, lines 362 to 369 in ccfc02b

As long as you're not using pipeline parallelism, you should be able to load the checkpoint by setting
EDIT: After reading #201, it seems we can treat our conversion from torch to DCP as a seed checkpoint, meaning that it needs to be a complete checkpoint if we want to use the features that TorchTitan has to offer. So I would suggest that we import precompute_freqs_cis. It seems quite easy to pass in the correct args now that you have the model loaded anyway during conversion:

from torchtitan.models.llama.model import precompute_freqs_cis

def _precompute_freqs_cis(self) -> torch.Tensor:
    return precompute_freqs_cis(
        self.model_args.dim // self.model_args.n_heads,
        # Need to compute until at least the max token limit for generation
        # TODO: explain in docs/composability.md why we removed the 2x
        # relaxing in our CP enablement PR
        self.model_args.max_seq_len,
        self.model_args.rope_theta,
    )
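A minimal sketch of what that could look like inside the conversion script, assuming the converted state_dict uses the same key as the model's persistent freqs_cis buffer; model_args here is a hypothetical ModelArgs for the checkpoint being converted, not a variable from the PR.

from torchtitan.models.llama.model import precompute_freqs_cis

# Add the rotary-embedding buffer to the converted state_dict so the DCP
# checkpoint is "complete" in the seed-checkpoint sense discussed above.
state_dict["freqs_cis"] = precompute_freqs_cis(
    model_args.dim // model_args.n_heads,
    model_args.max_seq_len,
    model_args.rope_theta,
)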
Thanks @casper-hansen for helping test things, hopefully it works now!
@rlrs I tested the changes and the 8B now loads successfully for me. The loss on c4 starts at a low value.

EDIT: Also just tested Llama 3.1 70B. It also starts with a low loss.
We can try @jaysonfrancis's script in #640 to do inference and see if the output is meaningful.
@tianyu-l I just tested generation with the torch -> DCP checkpoint of Llama 3.1 8B; it seems fine to me.
Sounds terrific, any blockers for merging this script? It looks good to me.
Closes #305.
Just wanted to get this out here quickly.
The script is very simple since the weights are already in exactly the right format, names and all. All of the complexity is avoided by not using HF, so I believe that any HF-related functionality should live on their side. However, I do have a DCP -> HF export script which might be useful for some people, in case HF does not have or add one.
I'll be happy to add any needed documentation or tests.
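For readers skimming the thread, a hedged sketch of the overall flow the script implements. The names, the merge helper, and the call style are illustrative rather than the exact code in the PR diff, and DCP.save may require a single-rank process group on some PyTorch versions.

from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP


def convert(input_dir: Path, output_dir: Path) -> None:
    # 1. Load every pickled shard on CPU; mmap keeps peak memory close to one copy.
    shards = [
        torch.load(p, map_location="cpu", weights_only=True, mmap=True)
        for p in sorted(input_dir.glob("*.pth"))
    ]
    # 2. Merge the tensor-parallel shards into one flat state_dict
    #    (merge_shards is a placeholder for the script's concatenation logic).
    state_dict = shards[0] if len(shards) == 1 else merge_shards(shards)
    # 3. Write the merged weights out in DCP format.
    DCP.save(state_dict, checkpoint_id=str(output_dir))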