[Not for land] Settings to make Llama3-8B on 8 GPUs faster #615
Draft
awgu wants to merge 4 commits into gh/awgu/17/base from gh/awgu/17/head
Conversation
awgu added a commit that referenced this pull request on Oct 14, 2024:

ghstack-source-id: 26228ccad42d4c33cbfff742de7ad95e2a16dcde
Pull Request resolved: #615
facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 14, 2024
awgu changed the title from "[Not for land] Seettings to make Llama3-8B on 8 GPUs faster" to "[Not for land] Settings to make Llama3-8B on 8 GPUs faster" on Oct 14, 2024
awgu added a commit that referenced this pull request on Oct 14, 2024:

ghstack-source-id: c5767bb4f3d7ad3330953ab97b9f06ff5c6917f5
Pull Request resolved: #615

Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh
```

```
[rank0]:2024-10-14 11:58:53,071 - root - INFO - step: 1 loss: 12.2208 memory: 66.44GiB(69.93%) wps: 882 mfu: 5.17%
[rank0]:2024-10-14 11:58:53,071 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-14 11:58:54,196 - root - INFO - step: 2 loss: 12.0630 memory: 73.96GiB(77.85%) wps: 7,282 mfu: 42.64%
[rank0]:2024-10-14 11:58:55,322 - root - INFO - step: 3 loss: 11.7272 memory: 73.96GiB(77.85%) wps: 7,276 mfu: 42.60%
[rank0]:2024-10-14 11:58:56,448 - root - INFO - step: 4 loss: 11.2526 memory: 73.96GiB(77.85%) wps: 7,280 mfu: 42.63%
[rank0]:2024-10-14 11:58:57,575 - root - INFO - step: 5 loss: 10.7972 memory: 73.96GiB(77.85%) wps: 7,268 mfu: 42.56%
[rank0]:2024-10-14 11:58:58,699 - root - INFO - step: 6 loss: 10.5048 memory: 73.96GiB(77.85%) wps: 7,293 mfu: 42.70%
[rank0]:2024-10-14 11:58:59,824 - root - INFO - step: 7 loss: 10.3384 memory: 73.96GiB(77.85%) wps: 7,285 mfu: 42.66%
[rank0]:2024-10-14 11:59:00,952 - root - INFO - step: 8 loss: 10.3164 memory: 73.96GiB(77.85%) wps: 7,266 mfu: 42.55%
[rank0]:2024-10-14 11:59:02,083 - root - INFO - step: 9 loss: 10.0995 memory: 73.96GiB(77.85%) wps: 7,247 mfu: 42.44%
[rank0]:2024-10-14 11:59:03,211 - root - INFO - step: 10 loss: 9.9308 memory: 73.96GiB(77.85%) wps: 7,264 mfu: 42.54%
[rank0]:2024-10-14 11:59:04,337 - root - INFO - step: 11 loss: 9.5785 memory: 73.96GiB(77.85%) wps: 7,275 mfu: 42.60%
[rank0]:2024-10-14 11:59:05,465 - root - INFO - step: 12 loss: 9.5265 memory: 73.96GiB(77.85%) wps: 7,267 mfu: 42.56%
[rank0]:2024-10-14 11:59:06,595 - root - INFO - step: 13 loss: 9.3497 memory: 73.96GiB(77.85%) wps: 7,252 mfu: 42.47%
[rank0]:2024-10-14 11:59:06,601 - root - WARNING - Dataset c4_test is being re-looped
```

[ghstack-poisoned]
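As a sanity check, the logged mfu can be roughly reproduced from the logged wps. The sketch below is back-of-envelope only: the Llama3-8B shapes, the H100 bf16 peak, and the flops-per-token formula are all assumptions, and the repo's exact accounting (e.g. excluding embedding parameters from the 6N term) differs slightly.

```python
# Rough reproduction of the logged MFU from the logged wps.
# All constants are assumptions (Llama3-8B shapes, H100 SXM bf16 peak).
num_params = 8.0e9                       # approximate Llama3-8B parameter count
n_layers, n_heads, head_dim, seq_len = 32, 32, 128, 8192
peak_flops_per_gpu = 989e12              # assumed bf16 peak of one H100 SXM

# 6N for the dense matmuls (fwd + bwd) plus a per-token attention term.
flops_per_token = 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len

wps = 7280                               # per-GPU words/sec from the log above
mfu = wps * flops_per_token / peak_flops_per_gpu
print(f"mfu ~= {mfu:.1%}")               # ~44.8%, near the logged ~42.6%
```

Counting embedding parameters in the 6N term is what pushes this estimate slightly above the logged value.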
awgu added a commit that referenced this pull request on Oct 22, 2024:

ghstack-source-id: e479129225ebeb6d27086ec029d800f0c0b0838c
Pull Request resolved: #615

Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh
```

```
[rank0]:2024-10-21 21:23:32,899 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1759: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:/data/users/andgu/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /data/users/andgu/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]:2024-10-21 21:23:42,336 - root - INFO - step: 1 loss: 12.2799 memory: 63.45GiB(66.79%) wps: 868 mfu: 5.08%
[rank0]:2024-10-21 21:23:42,336 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-21 21:23:43,539 - root - INFO - step: 2 loss: 12.1023 memory: 70.96GiB(74.69%) wps: 6,813 mfu: 39.90%
[rank0]:2024-10-21 21:23:44,667 - root - INFO - step: 3 loss: 11.7899 memory: 70.96GiB(74.69%) wps: 7,263 mfu: 42.53%
[rank0]:2024-10-21 21:23:45,795 - root - INFO - step: 4 loss: 11.3163 memory: 70.96GiB(74.69%) wps: 7,264 mfu: 42.54%
[rank0]:2024-10-21 21:23:46,923 - root - INFO - step: 5 loss: 10.8908 memory: 70.96GiB(74.69%) wps: 7,262 mfu: 42.52%
[rank0]:2024-10-21 21:23:48,050 - root - INFO - step: 6 loss: 10.4146 memory: 70.96GiB(74.69%) wps: 7,275 mfu: 42.60%
[rank0]:2024-10-21 21:23:49,174 - root - INFO - step: 7 loss: 10.1523 memory: 70.96GiB(74.69%) wps: 7,288 mfu: 42.68%
[rank0]:2024-10-21 21:23:50,306 - root - INFO - step: 8 loss: 10.2847 memory: 70.96GiB(74.69%) wps: 7,240 mfu: 42.40%
[rank0]:2024-10-21 21:23:51,434 - root - INFO - step: 9 loss: 10.0047 memory: 70.96GiB(74.69%) wps: 7,263 mfu: 42.53%
[rank0]:2024-10-21 21:23:52,560 - root - INFO - step: 10 loss: 9.9882 memory: 70.96GiB(74.69%) wps: 7,279 mfu: 42.63%
[rank0]:2024-10-21 21:23:53,685 - root - INFO - step: 11 loss: 9.6261 memory: 70.96GiB(74.69%) wps: 7,285 mfu: 42.66%
[rank0]:2024-10-21 21:23:54,813 - root - INFO - step: 12 loss: 9.5229 memory: 70.96GiB(74.69%) wps: 7,265 mfu: 42.54%
[rank0]:2024-10-21 21:23:55,944 - root - INFO - step: 13 loss: 9.4371 memory: 70.96GiB(74.69%) wps: 7,244 mfu: 42.42%
[rank0]:2024-10-21 21:23:55,950 - root - WARNING - Dataset c4_test is being re-looped
```

[ghstack-poisoned]
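For context on the command above: `TORCH_NCCL_AVOID_RECORD_STREAMS=1` makes NCCL collectives manage buffer lifetimes without `Tensor.record_stream`, which otherwise can delay memory reuse, and `expandable_segments:True` lets the CUDA caching allocator grow existing segments instead of allocating new ones, reducing fragmentation. A minimal sketch of setting the same knobs from inside a launcher script (assuming it runs before torch initializes the allocator and the process group):

```python
# In-process equivalent of the two env vars from the run command.
# Set them before importing torch so the allocator and NCCL pick them up.
import os

os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  (imported after setting the env vars on purpose)
```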
awgu added a commit that referenced this pull request on Oct 22, 2024:

ghstack-source-id: bdaa49373bb992258483b4a6c5ceb37b826c0d86
Pull Request resolved: #615
Stack from ghstack (oldest at bottom):
Requires pytorch/pytorch#137922
Baseline (torch.compile, no AC) vs. with this PR: ~13% speedup.

Baseline (compile, no AC): [run logs/screenshot omitted]
With this PR: [run logs/screenshot omitted; see the run logs in the commit messages above]
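The baseline here already uses torch.compile. As a reference point, a minimal sketch of per-TransformerBlock compilation (the helper name and the `model.layers` container are assumptions; torchtitan's actual compile helper may differ):

```python
import torch

def apply_compile_sketch(model: torch.nn.Module) -> None:
    # Compile each transformer block individually (rather than the whole model)
    # so recompilations stay local to a block and composing with FSDP stays simple.
    # Assumes blocks live in a `model.layers` container, as in Llama-style models.
    for name, block in model.layers.named_children():
        model.layers.register_module(name, torch.compile(block))
```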
Coordinate descent tuning gives ~2% MFU boost.
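That knob is Inductor's `coordinate_descent_tuning` config flag; a minimal sketch of turning it on (how it is wired into the training config is not shown in this PR, so the placement below is an assumption):

```python
import torch
import torch._inductor.config as inductor_config

# Have Inductor search kernel launch configs by coordinate descent at compile
# time: noticeably longer first compile, usually a small steady-state speedup.
inductor_config.coordinate_descent_tuning = True

@torch.compile
def ffn(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # Tiny SwiGLU-style stand-in workload just to exercise the compiled path.
    return torch.nn.functional.silu(x @ w1) @ w2
```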