
Add script to convert pickled Llama weights to DCP #634

Open · wants to merge 9 commits into main

Conversation

@rlrs rlrs commented Oct 19, 2024

Closes #305.

Just wanted to get this out here quickly.
The script is very simple since the weights are already in the correct format, names and all. All of the complexity is avoided by not using HF, so I believe that any HF-related functionality should live on their side. However, I do have a DCP -> HF export script which might be useful for some people, in case HF does not have or add one.
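For reference, a minimal sketch of the core single-shard conversion being described (not the exact PR script; the paths and CLI handling are assumptions):

# Minimal sketch of the single-shard (8B) case; the original Meta checkpoint
# ships as consolidated.00.pth with parameter names torchtitan already expects.
from pathlib import Path

import torch
import torch.distributed.checkpoint as dcp


def convert(input_dir: Path, output_dir: Path) -> None:
    state_dict = torch.load(
        input_dir / "consolidated.00.pth",
        map_location="cpu",
        weights_only=True,
        mmap=True,
    )
    # DCP writes a .metadata file plus __*_*.distcp data files into output_dir.
    dcp.save(state_dict, checkpoint_id=str(output_dir))


if __name__ == "__main__":
    # Hypothetical paths; the real script takes them as CLI arguments.
    convert(Path("Llama-3.1-8B/original"), Path("outputs/checkpoint/step-0"))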

I'll be happy to add any needed documentation or tests.

@facebook-github-bot

Hi @rlrs!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the CLA Signed label on Oct 19, 2024

casper-hansen commented Oct 19, 2024

@rlrs Looking at the structure of the original weights, this script only covers the 8B model.


rlrs commented Oct 19, 2024

Ah right, thanks for noticing. Will update that in a bit.


rlrs commented Oct 19, 2024

Here we go: the latest commit hopefully works for both 70B and 405B, if I've understood it correctly. You do have to load all the shards at once with this approach; I'm not sure there's an easy way around it.

Currently testing to make sure it works with 70B; it takes longer to test than to write the code... Unfortunately I only have a node with 512 GB RAM, so I can't convert or test 405B myself.

I do suspect that someone with more knowledge about DCP could potentially process a shard at a time, or even produce multiple shards for more efficient loading. This seems difficult to do without knowing the parallelization scheme you plan to use though.

Edit: Verified that 70B works - unsure what a good test would be in the repo, if necessary. I loaded up the checkpoint and ensured that loss on a small dataset is low.

@casper-hansen

@rlrs Thanks for the update. Once the TorchTitan team has verified this, I will help by converting all the Llama models and uploading them to Huggingface in DCP format (if the team does not do this).

@tianyu-l tianyu-l requested a review from fduwjj October 21, 2024 21:45
@tianyu-l tianyu-l (Contributor) left a comment

The code looks good in general.

> I do suspect that someone with more knowledge about DCP could potentially process a shard at a time, or even produce multiple shards for more efficient loading.

Discussed with @wz337 offline; it seems it's not easy to have a fundamentally more memory-efficient implementation, unless one happens to have at least as many GPUs as the weights were originally saved from.

> Verified that 70B works - unsure what a good test would be in the repo, if necessary. I loaded up the checkpoint and ensured that loss on a small dataset is low.

I agree it's not clear what a unit test would look like. In that case, please just add a brief tutorial in docs/checkpoint.md.

@fegin fegin (Contributor) left a comment

Thanks for the script. Some general comments:

What's our target model? 8B? 70B? 405B? The tool is likely to be bound by time and storage, not memory.

For 70B, I think a couple of hours is a reasonable number: (70 * 4 * 1000 / 100) seconds per single load (torch.load) or save (DCP.save, torch.save). So 1-2 hours should be reasonable.

For 405B, if you use mmap=True, it is possible to convert 405B in 2 passes: the first pass converts the state_dict and the second pass saves with DCP. However, with 405B, we may end up with a super large file that the file system may not support. And this design means 4x load and save, (405 * 4 * 1000 / 100) * 4 seconds, roughly 18 hours if nothing goes wrong, so the tool must be carefully designed.

For the 8B and 70B models, it is reasonable to use this tool. For 405B, using the TorchTitan trainer to do the conversion may be a more reasonable approach.
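For concreteness, the back-of-the-envelope numbers above work out as follows (assuming 4 bytes per parameter and ~100 MB/s of disk throughput, as in the comment):

# Back-of-the-envelope I/O estimates from the comment above:
# params (billions) * 4 bytes/param * 1000 MB/GB = size in MB; divide by 100 MB/s.
def io_hours(params_b: float, passes: int = 1) -> float:
    return params_b * 4 * 1000 / 100 * passes / 3600


print(io_hours(70))      # ~0.8 h for a single torch.load or DCP.save of 70B
print(io_hours(70, 2))   # ~1.6 h for one load plus one save
print(io_hours(405, 4))  # ~18 h for the 4x load/save design discussed for 405B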


wz337 commented Oct 22, 2024

Some thoughts on the unit test. Open to discussion on whether this is necessary:

Maybe we can use the smallest Llama weights (3.2 has a 1B). We can:

  1. convert the weights into DCP format
  2. load both state dicts (DCP and original Llama weights) into memory
  3. concatenate the Llama weights if needed and compare with DTensor.full_tensor() to verify the weights are identical

As for the documentation, feel free to split checkpoint.md into different sections to cover conversion.
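A rough sketch of such a check, using DCP's format utilities to round-trip back to a plain state_dict rather than DTensor.full_tensor() so that no process group is needed (the paths and the single-file 1B layout are assumptions):

# Rough sketch of the comparison suggested above (paths are hypothetical).
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save


def check_conversion(original_pth: str, dcp_dir: str) -> None:
    # Collapse the DCP checkpoint back into a single torch.save file.
    dcp_to_torch_save(dcp_dir, "/tmp/roundtrip.pt")
    original = torch.load(original_pth, map_location="cpu", weights_only=True)
    roundtrip = torch.load("/tmp/roundtrip.pt", map_location="cpu", weights_only=True)
    # Any extra keys added during conversion would need to be excluded first.
    assert set(original) <= set(roundtrip)
    for name, tensor in original.items():
        assert torch.equal(tensor, roundtrip[name]), f"mismatch in {name}"


check_conversion("Llama-3.2-1B/original/consolidated.00.pth", "outputs/checkpoint/step-0")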


tianyu-l commented Oct 22, 2024

Thanks @fegin !

> For 70B, I think a couple of hours is a reasonable number: (70 * 4 * 1000 / 100) seconds per single load (torch.load) or save (DCP.save, torch.save). So 1-2 hours should be reasonable.

Can I ask where the 1000 and 100 come from? Just for my education.

1 GB = 1000 MB, and 100 means 100 MB/s for disk read/write, which may be slower or faster depending on the storage.

> For the 8B and 70B models, it is reasonable to use this tool. For 405B, using the TorchTitan trainer to do the conversion may be a more reasonable approach.

What is the TorchTitan trainer and how do we use it?

I meant changing TorchTitan's CheckpointManager to accept a customized function at the first iteration to load a customized checkpoint.


rlrs commented Oct 23, 2024

Addressed all the comments on the code now, I think. Will add some information to checkpoint.md in a bit.

> I meant changing TorchTitan's CheckpointManager to accept a customized function at the first iteration to load a customized checkpoint.

Do you think this is a better approach? If we follow that approach I don't think this script would be needed, but we could probably just change it into a loader for the CheckpointManager.

I think it is a better approach. But this requires some design. We can land this script first.

By the way, I benchmarked memory usage while converting 8B and the script peaks at 16 GB resident size. It seems that deleting weights from shards after use doesn't change anything.

Why 16GB? Isn't the dtype f32? Can we also try mmap option as mentioned in the comment?

@tianyu-l tianyu-l (Contributor) left a comment

Thanks! Looks almost ready to merge.

1. Can you please try @fegin 's suggestion of using mmap?

> Why 16GB? Isn't the dtype f32? Can we also try mmap option as mentioned in the comment?

I'm a bit confused. If 8B in float32 occupies 32GB, wouldn't this approach need 2x, i.e. 64GB, since it loads all the shards and then creates a new state_dict, each being a copy of the model?

2. Also please add a short tutorial in docs/checkpoint.md.


jaysonfrancis commented Oct 24, 2024

Below are my attempts at converting the 70B & 405B originals hosted on the HF Hub.

70B

models--meta-llama--Llama-3.1-70B/snapshots/349b2ddb53ce8f2849a6c168a81980ab25258dac/original

> 1747.32s user 5869.44s system 5888% cpu 2:09.36 total

190k -N  .metadata  
18G -N  __0_0.distcp  
18G -N  __0_1.distcp  
18G -N  __0_2.distcp  
18G -N  __0_3.distcp  
18G -N  __0_4.distcp  
18G -N  __0_5.distcp  
18G -N  __0_6.distcp  
18G -N  __0_7.distcp

405B

models--meta-llama--Llama-3.1-405B/snapshots/b906e4dc842aa489c962f9db26554dcfdde901fe/original/mp16

shards[i][f"layers.{layer}.attention.{wn}.weight"].view(
RuntimeError: shape '[0, 128, 16384]' is invalid for input of size 16777216

The issue is that with nested shards for 405B (as hosted on the HF Hub), you get num_shards (176) > num_heads (128), which incorrectly sets n_heads_per_shard=0.

I made changes to this script to consolidate each nested folder, resulting in num_shards=16, n_heads_per_shard=8.

After processing 105 of 126 layers, I hit OOM (>1.5 TB).

I worked around it by deallocating memory within the loop after processing each shard (peaked at ~840 GB).

> 2253.69s user 10059.64s system 1351% cpu 15:10.78 total

300k  .metadata
103G  __0_0.distcp
103G  __0_1.distcp
103G  __0_2.distcp
103G  __0_3.distcp
103G  __0_4.distcp
103G  __0_5.distcp
102G  __0_6.distcp
102G  __0_7.distcp

Python 3.11.10, PyTorch 2.5.0+cu124, Ubuntu 22.04 (2x EPYC 9600 Series w/ 1.5Ti Mem)


casper-hansen commented Oct 24, 2024

> The issue is that with nested shards for 405B (as hosted on the HF Hub), you get num_shards (176) > num_heads (128), which incorrectly sets n_heads_per_shard=0.
>
> I made changes to this script to consolidate each nested folder, resulting in num_shards=16, n_heads_per_shard=8.
>
> After processing 105 of 126 layers, I hit OOM (>1.5 TB).
>
> I worked around it by deallocating memory within the loop after processing each shard (peaked at ~840 GB).

@jaysonfrancis Could you add your changes in a comment or as a PR to @rlrs's branch?


rlrs commented Oct 24, 2024

> Thanks! Looks almost ready to merge.
>
> 1. Can you please try @fegin 's suggestion of using `mmap`?
>
> > Why 16GB? Isn't the dtype f32? Can we also try mmap option as mentioned in the comment?
>
> I'm a bit confused. If 8B in float32 occupies 32GB, wouldn't this approach need 2x, i.e. 64GB, since it loads all the shards and then creates a new state_dict, each being a copy of the model?
>
> 2. Also please add a short tutorial in `docs/checkpoint.md`.

The Llama 3 original weights are published in bf16, it seems. E.g. the pickle at https://huggingface.co/meta-llama/Llama-3.1-8B/tree/main/original has BFloat16Storage.

mmap is already enabled in the script now, by the way. My previous message didn't make much sense because in the un-sharded case nothing is copied.
I'll add explicit dels, which should help with the memory usage as mentioned.
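A minimal sketch of the mmap-plus-explicit-del pattern being described (illustrative only; the paths are assumed and the real script concatenates along different dims depending on the weight type):

# Illustrative only: load shards memory-mapped, then release each slice as
# soon as it has been merged so the peak stays near one full model copy.
from pathlib import Path

import torch

checkpoint_files = sorted(Path("Llama-3.1-70B/original").glob("consolidated.*.pth"))
shards = [
    torch.load(f, map_location="cpu", weights_only=True, mmap=True)
    for f in checkpoint_files
]
state_dict = {}
for key in list(shards[0].keys()):
    # pop() drops the mmap-backed tensors once they have been consumed;
    # dim=0 is a simplification (column- vs row-sharded weights differ).
    state_dict[key] = torch.cat([shard.pop(key) for shard in shards], dim=0)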

@jaysonfrancis I'd appreciate your fixes for loading 405B from HF; however, I am unsure whether we should support it. 405B on HF is not stored in the original format, which is why my code does not work for it. See here for an explanation: https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/original/mp16/README.md

@tianyu-l tianyu-l (Contributor) left a comment

Looks great to me! Thank you very much for contributing!

I wonder what's the peak memory usage now if we load the 8B model, with the help of mmap and del.

@tianyu-l

Thanks @jaysonfrancis !

> 405B on HF is not stored in the original format, which is why my code does not work for it.

I agree that we probably shouldn't land the adaptation to the special treatment from HF. It is OK to keep a branch or even a PR on it.

From an inline review thread on scripts/convert_llama_to_dcp.py, the shard-loading logic under discussion, including handling for checkpoints whose shards are split across nested folders:

    # Load shards
    subdirs = [folder for folder in input_dir.iterdir() if folder.is_dir()]
    if subdirs:
        checkpoint_folders = sorted(subdirs)
        logger.info(f"Loading original Llama weights from {len(checkpoint_folders)} folders")

        # Load all .pth files within each folder and treat each folder as a shard
        shards = []
        for folder in checkpoint_folders:
            shard = {}
            for pth_file in sorted(folder.glob("*.pth")):
                shard.update(torch.load(pth_file, map_location="cpu"))
            shards.append(shard)
    else:
        checkpoint_list = sorted([file for file in input_dir.rglob("*.pth")])
        logger.info(f"Loading original Llama weights from {len(checkpoint_list)} files")

        shards = [
            torch.load(ckpt, map_location="cpu")
            for ckpt in checkpoint_list
        ]



@jaysonfrancis

> I agree that we probably shouldn't land the adaptation to the special treatment from HF. It is OK to keep a branch or even a PR on it.

Sounds good, it was mentioned earlier in the thread so I figured it was needed. Thanks!


rlrs commented Oct 25, 2024

> I wonder what's the peak memory usage now if we load the 8B model, with the help of mmap and del.

I believe the peak should be one full weight copy + the single biggest matrix. So only slightly more than one full copy, assuming the DCP save call doesn't allocate much.
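In rough numbers for the 8B model (assuming bf16 storage and Llama 3's 128256-token vocabulary with a 4096 model dimension), that estimate lines up with the 16 GB peak measured earlier:

# Rough arithmetic for the 8B case, assuming bf16 (2 bytes/param) and
# Llama 3's vocab/model sizes (128256 tokens, dim 4096).
full_copy_gb = 8e9 * 2 / 1e9                 # ~16 GB for one full copy of the weights
largest_matrix_gb = 128256 * 4096 * 2 / 1e9  # ~1.05 GB token embedding / output matrix
print(full_copy_gb + largest_matrix_gb)      # ~17 GB expected peak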

@casper-hansen

> Thanks @jaysonfrancis !
>
> > 405B on HF is not stored in the original format, which is why my code does not work for it.
>
> I agree that we probably shouldn't land the adaptation to the special treatment from HF. It is OK to keep a branch or even a PR on it.

The reason it's nice to have HF support is that the download speed is much faster. I just tried to start a download of the 405B from Meta and it runs at 2 MB/s, whereas HF can give you speeds of several GB/s. Are there any alternatives for downloading the 405B weights to test this PR?


rlrs commented Oct 25, 2024

> Are there any alternatives for downloading the 405B weights to test this PR?

You can run the code HF provides for undoing the additional sharding they've performed; see my link above.

@casper-hansen

> You can run the code HF provides for undoing the additional sharding they've performed; see my link above.

I converted the full Llama 3.1 series to DCP by using your script. The 405B weights took about one hour from torch to DCP after I converted them from the sharded format back to torch.


casper-hansen commented Oct 28, 2024

Can you explain how to run with the weights after conversion to DCP, @rlrs? I put the DCP weights in a step-0 folder and pointed TorchTitan to the weights. However, it doesn't seem to work out of the box with TorchTitan because of a missing key for freqs_cis (Llama 3.1 8B).

The error triggers here:

dcp.load(
    states,
    checkpoint_id=self._create_checkpoint_id(step),
)

Error:

File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 354, in create_default_local_load_plan
  raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
RuntimeError: Missing key in checkpoint state_dict: model.freqs_cis.


rlrs commented Oct 28, 2024

Oh no, I probably overlooked it because my code doesn't load freqs_cis. It shouldn't normally be necessary since it can easily be computed, but I see that torchtitan by default persists it for a couple of reasons:

# TODO persistent should be set to false, since this buffer can be recomputed.
# however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
# compile or pipeline-tracer will not correctly handle non-persistent buffers,
# so we need to fix that. (2) if we initialize pipeline-parallel models from
# a seed checkpoint rather than calling init_weights, we need freqs_cis to be
# initialized by the checkpoint, or we need to add a separate initializer for
# just the non-persistent buffers that is called after loading checkpoints.
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

As long as you're not using pipeline parallelism, you should be able to load the checkpoint by setting persistent=False there. However, as long as this is set to True by default, we should probably also save freqs_cis to DCP.


casper-hansen commented Oct 28, 2024

> As long as you're not using pipeline parallelism, you should be able to load the checkpoint by setting persistent=False there. However, as long as this is set to True by default, we should probably also save freqs_cis to DCP.

EDIT: After reading #201, it seems we can treat our conversion from torch to DCP as a seed checkpoint, meaning it needs to be a complete checkpoint if we want to use the features that TorchTitan has to offer. So I would suggest that we import precompute_freqs_cis and use it to correctly set the freqs_cis key. Given that the bug is in PyTorch and TorchTitan has implemented a specific workaround, I think this is a reasonable solution.

It seems quite easy to pass in the correct args now that you have the model loaded anyway during conversion:

from torchtitan.models.llama.model import precompute_freqs_cis

    def _precompute_freqs_cis(self) -> torch.Tensor:
        return precompute_freqs_cis(
            self.model_args.dim // self.model_args.n_heads,
            # Need to compute until at least the max token limit for generation
            # TODO: explain in docs/composability.md why we removed the 2x
            # relaxing in our CP enablement PR
            self.model_args.max_seq_len,
            self.model_args.rope_theta,
        )
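For example, the conversion script could populate the missing buffer along these lines (a sketch only; the exact state_dict key and where model_args comes from are assumptions, not the final PR code):

# Sketch: compute the persistent freqs_cis buffer during conversion and store
# it under the key torchtitan expects (the key name and model_args are assumed).
from torchtitan.models.llama.model import precompute_freqs_cis

state_dict["freqs_cis"] = precompute_freqs_cis(
    model_args.dim // model_args.n_heads,  # head dimension
    model_args.max_seq_len,
    model_args.rope_theta,
)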


rlrs commented Oct 30, 2024

Thanks @casper-hansen for helping test things, hopefully it works now!


casper-hansen commented Oct 30, 2024

@rlrs I tested the changes and the 8B now loads successfully for me. The loss on c4 starts at 2.4576 when loading the pretrained weights, compared to 12.2599 with newly initialised weights. At the moment, this is the only way I can verify the weights are working as intended apart from running evaluations.

EDIT: Also just tested Llama 3.1 70B; it starts with a loss of 2.0512.

@tianyu-l

@casper-hansen

> At the moment, this is the only way I can verify the weights are working as intended apart from running evaluations.

We can try @jaysonfrancis 's script #640 to do inference and see if it's meaningful.


casper-hansen commented Oct 31, 2024

@tianyu-l I just tested generation with the torch -> DCP checkpoint of Llama 3.1 8B; it seems fine to me.

  • Input: "<|begin_of_text|>A curious person is"
  • Output: " always ready with questions to ask and especially on this quest of our spiritual journey. This time\nwe shall start with a big question of importance and relevance .\nA"

@tianyu-l

Sounds terrific, any blockers for merging this script? It looks good to me.

Labels: CLA Signed
Successfully merging this pull request may close these issues: reload existing llama checkpoints
7 participants