[Not for land] Settings to make Llama3-8B on 8 GPUs faster #615

Draft · wants to merge 4 commits into gh/awgu/17/base
Conversation

@awgu (Contributor) commented Oct 14, 2024

Stack from ghstack (oldest at bottom):

Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh
```
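
For reference, the same two settings can also be applied from Python instead of the shell; a minimal sketch (both are read during initialization, so they must be set before `torch` brings up CUDA/NCCL):

```python
# Minimal sketch: the two env-var settings from the command above, applied
# in the training entry point before importing torch.
import os

# Skip the record_stream-based lifetime tracking in ProcessGroupNCCL;
# tensor references are instead held until the collective completes.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

# Use expandable memory segments in the CUDA caching allocator to reduce
# fragmentation and peak reserved memory.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  -- imported after the env vars on purpose
```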

Baseline (torch.compile, no AC):

  • ~6690 WPS, 39.2% MFU

With this PR (~13% speedup; see the quick check after this list):

  • ~7570 WPS, 44.3% MFU
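
As a sanity check, the quoted ~13% falls straight out of the steady-state WPS ratio, and the MFU ratio agrees (MFU is proportional to WPS at fixed model size and sequence length):

```python
# Arithmetic behind the quoted speedup, using the numbers above.
baseline_wps, pr_wps = 6690, 7570
baseline_mfu, pr_mfu = 39.2, 44.3

print(f"WPS ratio: {pr_wps / baseline_wps:.3f}")  # ~1.132, i.e. ~13% faster
print(f"MFU ratio: {pr_mfu / baseline_mfu:.3f}")  # ~1.130, consistent
```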
Baseline (compile, no AC):

```
[rank0]:2024-10-21 21:33:48,351 - root - INFO - step:  1  loss: 12.2449  memory: 65.62GiB(69.07%)  wps: 878  mfu: 5.14%
[rank0]:2024-10-21 21:33:48,351 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-21 21:33:49,604 - root - INFO - step:  2  loss: 12.0784  memory: 73.28GiB(77.13%)  wps: 6,540  mfu: 38.30%
[rank0]:2024-10-21 21:33:50,829 - root - INFO - step:  3  loss: 11.7474  memory: 73.28GiB(77.13%)  wps: 6,689  mfu: 39.17%
[rank0]:2024-10-21 21:33:52,053 - root - INFO - step:  4  loss: 11.2879  memory: 73.28GiB(77.13%)  wps: 6,698  mfu: 39.23%
[rank0]:2024-10-21 21:33:53,276 - root - INFO - step:  5  loss: 10.7903  memory: 73.28GiB(77.13%)  wps: 6,699  mfu: 39.23%
[rank0]:2024-10-21 21:33:54,499 - root - INFO - step:  6  loss: 10.5275  memory: 73.28GiB(77.13%)  wps: 6,699  mfu: 39.23%
[rank0]:2024-10-21 21:33:55,723 - root - INFO - step:  7  loss: 10.2621  memory: 73.28GiB(77.13%)  wps: 6,697  mfu: 39.22%
[rank0]:2024-10-21 21:33:56,950 - root - INFO - step:  8  loss: 10.3680  memory: 73.28GiB(77.13%)  wps: 6,680  mfu: 39.12%
[rank0]:2024-10-21 21:33:58,180 - root - INFO - step:  9  loss:  9.8565  memory: 73.28GiB(77.13%)  wps: 6,662  mfu: 39.01%
[rank0]:2024-10-21 21:33:59,405 - root - INFO - step: 10  loss:  9.8140  memory: 73.28GiB(77.13%)  wps: 6,691  mfu: 39.18%
[rank0]:2024-10-21 21:34:00,629 - root - INFO - step: 11  loss:  9.5486  memory: 73.28GiB(77.13%)  wps: 6,695  mfu: 39.20%
[rank0]:2024-10-21 21:34:01,854 - root - INFO - step: 12  loss:  9.4141  memory: 73.28GiB(77.13%)  wps: 6,693  mfu: 39.20%
[rank0]:2024-10-21 21:34:03,083 - root - INFO - step: 13  loss:  9.2171  memory: 73.28GiB(77.13%)  wps: 6,668  mfu: 39.05%
[rank0]:2024-10-21 21:34:03,088 - root - WARNING - Dataset c4_test is being re-looped
```

With this PR:

```
[rank0]:2024-10-21 21:27:03,502 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1759: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:/data/users/andgu/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /data/users/andgu/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
[rank0]:  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:2024-10-21 21:27:47,325 - root - INFO - step:  1  loss: 12.2214  memory: 63.69GiB(67.04%)  wps: 187  mfu: 1.09%
[rank0]:2024-10-21 21:27:47,325 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-21 21:27:48,940 - root - INFO - step:  2  loss: 12.0538  memory: 70.96GiB(74.69%)  wps: 5,072  mfu: 29.70%
[rank0]:2024-10-21 21:27:50,023 - root - INFO - step:  3  loss: 11.7264  memory: 70.96GiB(74.69%)  wps: 7,567  mfu: 44.31%
[rank0]:2024-10-21 21:27:51,104 - root - INFO - step:  4  loss: 11.2348  memory: 70.96GiB(74.69%)  wps: 7,584  mfu: 44.41%
[rank0]:2024-10-21 21:27:52,189 - root - INFO - step:  5  loss: 10.7867  memory: 70.96GiB(74.69%)  wps: 7,553  mfu: 44.23%
[rank0]:2024-10-21 21:27:53,268 - root - INFO - step:  6  loss: 10.5913  memory: 70.96GiB(74.69%)  wps: 7,590  mfu: 44.45%
[rank0]:2024-10-21 21:27:54,349 - root - INFO - step:  7  loss: 10.3179  memory: 70.96GiB(74.69%)  wps: 7,585  mfu: 44.42%
[rank0]:2024-10-21 21:27:55,434 - root - INFO - step:  8  loss: 10.3510  memory: 70.96GiB(74.69%)  wps: 7,553  mfu: 44.23%
[rank0]:2024-10-21 21:27:56,521 - root - INFO - step:  9  loss: 10.0707  memory: 70.96GiB(74.69%)  wps: 7,538  mfu: 44.14%
[rank0]:2024-10-21 21:27:57,603 - root - INFO - step: 10  loss:  9.8471  memory: 70.96GiB(74.69%)  wps: 7,572  mfu: 44.34%
[rank0]:2024-10-21 21:27:58,686 - root - INFO - step: 11  loss:  9.6122  memory: 70.96GiB(74.69%)  wps: 7,567  mfu: 44.31%
[rank0]:2024-10-21 21:27:59,769 - root - INFO - step: 12  loss:  9.4770  memory: 70.96GiB(74.69%)  wps: 7,567  mfu: 44.31%
[rank0]:2024-10-21 21:28:00,855 - root - INFO - step: 13  loss:  9.5777  memory: 70.96GiB(74.69%)  wps: 7,544  mfu: 44.17%
[rank0]:2024-10-21 21:28:00,861 - root - WARNING - Dataset c4_test is being re-looped
```

Coordinate descent tuning gives a further ~2% MFU boost.
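
For anyone reproducing this: coordinate descent tuning is an Inductor config knob, `torch._inductor.config.coordinate_descent_tuning`. A minimal sketch of turning it on, assuming a recent PyTorch build:

```python
# Minimal sketch of enabling the coordinate descent tuning mentioned above.
# Inductor then refines its autotuned Triton kernel configs by coordinate
# descent, trading extra compile time for faster generated kernels.
import torch
import torch._inductor.config as inductor_config

inductor_config.coordinate_descent_tuning = True
# The same knob can usually be set via the environment:
#   TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1

model = torch.nn.Linear(4096, 4096).cuda().bfloat16()
compiled = torch.compile(model)  # tuning happens during compilation
out = compiled(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
```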

awgu added a commit that referenced this pull request Oct 14, 2024
ghstack-source-id: 26228ccad42d4c33cbfff742de7ad95e2a16dcde
Pull Request resolved: #615
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Oct 14, 2024
@awgu changed the title from [Not for land] Seettings to make Llama3-8B on 8 GPUs faster to [Not for land] Settings to make Llama3-8B on 8 GPUs faster Oct 14, 2024
awgu added a commit that referenced this pull request Oct 14, 2024
ghstack-source-id: c5767bb4f3d7ad3330953ab97b9f06ff5c6917f5
Pull Request resolved: #615
Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh 
```

```
[rank0]:2024-10-14 11:58:53,071 - root - INFO - step:  1  loss: 12.2208  memory: 66.44GiB(69.93%)  wps: 882  mfu: 5.17%
[rank0]:2024-10-14 11:58:53,071 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-14 11:58:54,196 - root - INFO - step:  2  loss: 12.0630  memory: 73.96GiB(77.85%)  wps: 7,282  mfu: 42.64%
[rank0]:2024-10-14 11:58:55,322 - root - INFO - step:  3  loss: 11.7272  memory: 73.96GiB(77.85%)  wps: 7,276  mfu: 42.60%
[rank0]:2024-10-14 11:58:56,448 - root - INFO - step:  4  loss: 11.2526  memory: 73.96GiB(77.85%)  wps: 7,280  mfu: 42.63%
[rank0]:2024-10-14 11:58:57,575 - root - INFO - step:  5  loss: 10.7972  memory: 73.96GiB(77.85%)  wps: 7,268  mfu: 42.56%
[rank0]:2024-10-14 11:58:58,699 - root - INFO - step:  6  loss: 10.5048  memory: 73.96GiB(77.85%)  wps: 7,293  mfu: 42.70%
[rank0]:2024-10-14 11:58:59,824 - root - INFO - step:  7  loss: 10.3384  memory: 73.96GiB(77.85%)  wps: 7,285  mfu: 42.66%
[rank0]:2024-10-14 11:59:00,952 - root - INFO - step:  8  loss: 10.3164  memory: 73.96GiB(77.85%)  wps: 7,266  mfu: 42.55%
[rank0]:2024-10-14 11:59:02,083 - root - INFO - step:  9  loss: 10.0995  memory: 73.96GiB(77.85%)  wps: 7,247  mfu: 42.44%
[rank0]:2024-10-14 11:59:03,211 - root - INFO - step: 10  loss:  9.9308  memory: 73.96GiB(77.85%)  wps: 7,264  mfu: 42.54%
[rank0]:2024-10-14 11:59:04,337 - root - INFO - step: 11  loss:  9.5785  memory: 73.96GiB(77.85%)  wps: 7,275  mfu: 42.60%
[rank0]:2024-10-14 11:59:05,465 - root - INFO - step: 12  loss:  9.5265  memory: 73.96GiB(77.85%)  wps: 7,267  mfu: 42.56%
[rank0]:2024-10-14 11:59:06,595 - root - INFO - step: 13  loss:  9.3497  memory: 73.96GiB(77.85%)  wps: 7,252  mfu: 42.47%
[rank0]:2024-10-14 11:59:06,601 - root - WARNING - Dataset c4_test is being re-looped
```

[ghstack-poisoned]
awgu added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: e479129225ebeb6d27086ec029d800f0c0b0838c
Pull Request resolved: #615
Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh 
```

```
[rank0]:2024-10-21 21:23:32,899 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1759: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:/data/users/andgu/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /data/users/andgu/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
[rank0]:  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:2024-10-21 21:23:42,336 - root - INFO - step:  1  loss: 12.2799  memory: 63.45GiB(66.79%)  wps: 868  mfu: 5.08%
[rank0]:2024-10-21 21:23:42,336 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-21 21:23:43,539 - root - INFO - step:  2  loss: 12.1023  memory: 70.96GiB(74.69%)  wps: 6,813  mfu: 39.90%
[rank0]:2024-10-21 21:23:44,667 - root - INFO - step:  3  loss: 11.7899  memory: 70.96GiB(74.69%)  wps: 7,263  mfu: 42.53%
[rank0]:2024-10-21 21:23:45,795 - root - INFO - step:  4  loss: 11.3163  memory: 70.96GiB(74.69%)  wps: 7,264  mfu: 42.54%
[rank0]:2024-10-21 21:23:46,923 - root - INFO - step:  5  loss: 10.8908  memory: 70.96GiB(74.69%)  wps: 7,262  mfu: 42.52%
[rank0]:2024-10-21 21:23:48,050 - root - INFO - step:  6  loss: 10.4146  memory: 70.96GiB(74.69%)  wps: 7,275  mfu: 42.60%
[rank0]:2024-10-21 21:23:49,174 - root - INFO - step:  7  loss: 10.1523  memory: 70.96GiB(74.69%)  wps: 7,288  mfu: 42.68%
[rank0]:2024-10-21 21:23:50,306 - root - INFO - step:  8  loss: 10.2847  memory: 70.96GiB(74.69%)  wps: 7,240  mfu: 42.40%
[rank0]:2024-10-21 21:23:51,434 - root - INFO - step:  9  loss: 10.0047  memory: 70.96GiB(74.69%)  wps: 7,263  mfu: 42.53%
[rank0]:2024-10-21 21:23:52,560 - root - INFO - step: 10  loss:  9.9882  memory: 70.96GiB(74.69%)  wps: 7,279  mfu: 42.63%
[rank0]:2024-10-21 21:23:53,685 - root - INFO - step: 11  loss:  9.6261  memory: 70.96GiB(74.69%)  wps: 7,285  mfu: 42.66%
[rank0]:2024-10-21 21:23:54,813 - root - INFO - step: 12  loss:  9.5229  memory: 70.96GiB(74.69%)  wps: 7,265  mfu: 42.54%
[rank0]:2024-10-21 21:23:55,944 - root - INFO - step: 13  loss:  9.4371  memory: 70.96GiB(74.69%)  wps: 7,244  mfu: 42.42%
[rank0]:2024-10-21 21:23:55,950 - root - WARNING - Dataset c4_test is being re-looped
```

[ghstack-poisoned]
awgu added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: bdaa49373bb992258483b4a6c5ceb37b826c0d86
Pull Request resolved: #615