Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ascend NPU as a backend #1826

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

noemotiovon
Copy link

@noemotiovon noemotiovon commented Oct 14, 2024

What does this PR do?

Overview

🚀This PR enables the users of torhtune to leverage the Ascend NPU for better performance in inferencing when GPU device is not available.
This PR primarily addresses the initial refactoring of device-independent code. In upcoming changes, we’ll focus on further adjustments, using NPU as an example to refine each recipe and complete the remaining device-independent modifications. For now, this PR only touches on recipe lora_finetune_single_device and full_finetune_single_device.

For more details, see: [#1797].

Environment

  • OS: ubuntu 20.04
  • NPU: Atlas 300T A2
  • CANN: 8.0.RC2
  • torch-npu: 2.4.0 rc1
  • torch: 2.4.0

Note

To properly install CANN, see [here] for more details.

The version of torch-npu should match that of torch, see [here] for more details.

In addition, torch_npu has a pre-release version, 2.4.0 RC1, which is also the basis for this test. For more information, please visit [here].

Examples

To start with, the library torch_npu should be correctly installed and imported. Part of the codes are showed below:

torchtune/utils/_device_support.py:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

def is_torch_npu_available():
    try:
        import torch_npu # noqa: F401
    except ImportError:
        return False
    return torch.npu.is_available()

Plus, there are some other places of the codes might be adjusted, which won't be too much.

Feel free to leave comments to guide me in further improvements 😊.

Copy link

pytorch-bot bot commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1826

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Cancelled Jobs

As of commit 99a6dd8 with merge base e99b890 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 14, 2024
@noemotiovon noemotiovon marked this pull request as draft October 14, 2024 10:21
torchtune/utils/_device_support.py Outdated Show resolved Hide resolved
recipes/lora_finetune_single_device.py Outdated Show resolved Hide resolved
recipes/quantize.py Outdated Show resolved Hide resolved
@noemotiovon noemotiovon marked this pull request as ready for review October 21, 2024 10:49
@noemotiovon
Copy link
Author

Hi @ebsmothers, @RdoubleA:

I hope you’re doing well! Could you please help me review my code? I would really appreciate it if you could take a look and share any feedback or suggestions. Thank you so much in advance for your time and support! 😊

Best regards

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @noemotiovon thanks for the PR! And apologies for the delay in getting to the review here. A couple other questions I have that don't really fit neatly anywhere inline:

  1. Do we expect compile to work? If so, we should test that. If not, we could raise an error
  2. Do we expect quant-related APIs (e.g. QLoRA or QAT) from torchao to work? Same as point 1: if so we should test or possibly raise an error
  3. PyTorch has now released 2.5 as stable. In general we do not claim to support anything but the latest stable release of PyTorch -- do you know the contract on torch_npu releases here?

torchtune/training/precision.py Outdated Show resolved Hide resolved
recipes/full_finetune_distributed.py Outdated Show resolved Hide resolved
recipes/knowledge_distillation_single_device.py Outdated Show resolved Hide resolved
tests/torchtune/utils/test_device.py Show resolved Hide resolved
torchtune/training/_activation_offloading.py Outdated Show resolved Hide resolved
torchtune/utils/_device.py Show resolved Hide resolved
@elfisworking
Copy link

elfisworking commented Oct 22, 2024

distributed training seems to have problems e.g qat_distributed @noemotiovon
function torchtune/training/_distributed.py/load_from_full_model_state_dict
sharded_meta_param.device_mesh.device_type is cpu when loading model which would raise error.
Are you willing to connect me through my github profile email? maybe we can discuss how to support ascend npu

@noemotiovon
Copy link
Author

distributed training seems to have problems e.g qat_distributed @noemotiovon function torchtune/training/_distributed.py/load_from_full_model_state_dict sharded_meta_param.device_mesh.device_type is cpu when loading model which would raise error. Are you willing to connect me through my github profile email? maybe we can discuss how to support ascend npu

I would be very happy to! I will contact you via email.

@elfisworking
Copy link

distributed training seems to have problems e.g qat_distributed @noemotiovon function torchtune/training/_distributed.py/load_from_full_model_state_dict sharded_meta_param.device_mesh.device_type is cpu when loading model which would raise error. Are you willing to connect me through my github profile email? maybe we can discuss how to support ascend npu

I would be very happy to! I will contact you via email.

@noemotiovon through 126 email thanks. Looking forward to your email.

torchtune/utils/_device.py Outdated Show resolved Hide resolved
torchtune/utils/_device.py Outdated Show resolved Hide resolved
@noemotiovon
Copy link
Author

Basic Usage Test

A single-device fine-tuning process was performed on the Llama 3.1 8B model using the LoRA (Low-Rank Adaptation) technique.

  • Recipe: lora_finetune_single_device

  • Model: Meta-Llama-3.1-8B-Instruct

  • Config:

    # Config for single device LoRA finetuning in lora_finetune_single_device.py
    # using a Llama3.1 8B Instruct model
    #
    # This config assumes that you've run the following command before launching
    # this run:
    #   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
    #
    # To launch on a single device, run the following command from root:
    #   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
    #
    # You can add specific overrides through the command line. For example
    # to override the checkpointer directory while launching training
    # you can run:
    #   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
    #
    # This config works only for training on single device.
    
    
    # Model Arguments
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      lora_attn_modules: ['q_proj', 'v_proj']
      apply_lora_to_mlp: False
      apply_lora_to_output: False
      lora_rank: 8
      lora_alpha: 16
      lora_dropout: 0.0
    
    # Tokenizer
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
      max_seq_len: null
    
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files: [
        model-00001-of-00004.safetensors,
        model-00002-of-00004.safetensors,
        model-00003-of-00004.safetensors,
        model-00004-of-00004.safetensors
      ]
      recipe_checkpoint: null
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      model_type: LLAMA3
    resume_from_checkpoint: False
    save_adapter_weights_only: False
    
    # Dataset and Sampler
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    seed: null
    shuffle: True
    batch_size: 30
    
    # Optimizer and Scheduler
    optimizer:
      _component_: torch.optim.AdamW
      fused: False
      weight_decay: 0.01
      lr: 3e-4
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    
    # Training
    epochs: 1
    max_steps_per_epoch: null
    gradient_accumulation_steps: 64
    compile: False
    
    # Logging
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: ${output_dir}
    log_every_n_steps: 1
    log_peak_memory_stats: False
    
    # Environment
    device: npu
    dtype: bf16
    
    # Activations Memory
    enable_activation_checkpointing: True
    enable_activation_offloading: False
    
    # Profiler (disabled)
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      enabled: False
    
      #Output directory of trace artifacts
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    
      #`torch.profiler.ProfilerActivity` types to trace
      cpu: True
      cuda: True
    
      #trace options passed to `torch.profiler.profile`
      profile_memory: False
      with_stack: False
      record_shapes: True
      with_flops: False
    
      # `torch.profiler.schedule` options:
      # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
      wait_steps: 5
      warmup_steps: 5
      active_steps: 2
      num_cycles: 1
  • Logs:

    INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 30
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files:
      - model-00001-of-00004.safetensors
      - model-00002-of-00004.safetensors
      - model-00003-of-00004.safetensors
      - model-00004-of-00004.safetensors
      model_type: LLAMA3
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    enable_activation_offloading: false
    epochs: 1
    gradient_accumulation_steps: 64
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      apply_lora_to_mlp: false
      apply_lora_to_output: false
      lora_alpha: 16
      lora_attn_modules:
      - q_proj
      - v_proj
      lora_dropout: 0.0
      lora_rank: 8
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 0.0003
      weight_decay: 0.01
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      active_steps: 2
      cpu: true
      cuda: true
      enabled: false
      num_cycles: 1
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
      profile_memory: false
      record_shapes: true
      wait_steps: 5
      warmup_steps: 5
      with_flops: false
      with_stack: false
    resume_from_checkpoint: false
    save_adapter_weights_only: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      max_seq_len: null
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 4173222699. Local seed is seed + rank = 4173222699 + 0
    Writing logs to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test/log_1728874769.txt
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 16.98 GiB
            NPU peak memory reserved: 17.00 GiB
            NPU peak memory active: 16.98 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer and loss are initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                                            | 0/26 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
      4%|█████▌                                                                                                                                           | 1/26 [07:09<2:59:1|1|Loss: 1.7533539533615112:   4%|████▍                                                                                                              | 1/26 [07:09<2:59:1|1|Loss: 1.7533539533615112:   8%|████████▊                                                                                                          | 2/26 [14:09<2:49:1|2|Loss: 1.7825285196304321:   8%|████████▊                                                                                                          | 2/26 [14:09<2:49:1|2|Loss: 1.7825285196304321:  12%|█████████████▎                                                                                                     | 3/26 [21:17<2:43:1|3|Loss: 1.7610299587249756:  12%|█████████████▎                                                                                                     | 3/26 [21:17<2:43:1|3|Loss: 1.7610299587249756:  15%|█████████████████▋                                                                                                 | 4/26 [28:28<2:36:1|4|Loss: 1.7874119281768799:  15%|█████████████████▋                                                                                                 | 4/26 [28:28<2:36:1|4|Loss: 1.7874119281768799:  19%|██████████████████████                                                                                             | 5/26 [35:36<2:29:1|5|Loss: 1.7903798818588257:  19%|██████████████████████                                                                                             | 5/26 [35:36<2:29:1|5|Loss: 1.7903798818588257:  23%|██████████████████████████▌                                                                                        | 6/26 [42:52<2:23:1|6|Loss: 1.776786208152771:  23%|██████████████████████████▊                                                                                         | 6/26 [42:52<2:23:1|6|Loss: 1.776786208152771:  27%|███████████████████████████████▏                                                                                    | 7/26 [49:45<2:14:1|7|Loss: 1.7698196172714233:  27%|██████████████████████████████▉                                                                                    | 7/26 [49:45<2:14:29, 424.69s/it]
     *  History restored 
    
    (torchtune) (base) lcg@lcg-docker:~/github/torchtune$  
    (torchtune) (base) lcg@lcg-docker:~/github/torchtune$ 
    (torchtune) (base) lcg@lcg-docker:~/github/torchtune$ tune run lora_finetune_single_device --config my_custom_config.yaml
    INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 30
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files:
      - model-00001-of-00004.safetensors
      - model-00002-of-00004.safetensors
      - model-00003-of-00004.safetensors
      - model-00004-of-00004.safetensors
      model_type: LLAMA3
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    enable_activation_offloading: false
    epochs: 1
    gradient_accumulation_steps: 64
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      apply_lora_to_mlp: false
      apply_lora_to_output: false
      lora_alpha: 16
      lora_attn_modules:
      - q_proj
      - v_proj
      lora_dropout: 0.0
      lora_rank: 8
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 0.0003
      weight_decay: 0.01
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      active_steps: 2
      cpu: true
      cuda: true
      enabled: false
      num_cycles: 1
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
      profile_memory: false
      record_shapes: true
      wait_steps: 5
      warmup_steps: 5
      with_flops: false
      with_stack: false
    resume_from_checkpoint: false
    save_adapter_weights_only: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      max_seq_len: null
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 1031355438. Local seed is seed + rank = 1031355438 + 0
    Writing logs to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test/log_1728878132.txt
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 16.98 GiB
            NPU peak memory reserved: 17.00 GiB
            NPU peak memory active: 16.98 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer and loss are initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                             | 0/26 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
    1|26|Loss: 1.427944302558899: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [3:04:04<00:00, 418.75s/it]INFO:torchtune.utils._logging:Starting checkpoint save...
    INFO:torchtune.utils._logging:Model checkpoint of size 4.98 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0001_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0002_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 4.92 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0003_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 1.17 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0004_0.pt
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.01 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_0.pt
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.01 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_model.bin
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_config.json
    INFO:torchtune.utils._logging:Saving final epoch checkpoint.
    INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
    INFO:torchtune.utils._logging:Checkpoint saved in 65.93 seconds.
    1|26|Loss: 1.427944302558899: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [3:11:40<00:00, 442.34s/it]
    
  • Result: The test results demonstrate the successful completion of a single-device LoRA fine-tuning process on the Llama 3.1 8B model. The configuration included a batch size of 30, gradient accumulation over 64 steps, and one epoch of training on an NPU device using the bf16 data type. Activation checkpointing was enabled, and LoRA fine-tuning was applied to attention modules. The process utilized AdamW as the optimizer with a learning rate of 0.0003 and a cosine learning rate scheduler.

@noemotiovon
Copy link
Author

Basic Usage Test

A single-device full fine-tuning process was performed on the Qwen2 0.5B model using the LoRA (Low-Rank Adaptation) technique.

  • Recipe: full_finetune_single_device

  • Model: Qwen2-0.5B-Instruct

  • Config:

    # Tokenizer
    tokenizer:
      _component_: torchtune.models.qwen2.qwen2_tokenizer
      path: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/vocab.json
      merges_file: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/merges.txt
      max_seq_len: null
    
    # Dataset
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    seed: null
    shuffle: True
    
    # Model Arguments
    model:
      _component_: torchtune.models.qwen2.qwen2_0_5b
    
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct
      checkpoint_files: [
        model.safetensors
      ]
      recipe_checkpoint: null
      output_dir: /home/lcg/tmp/torchtune/Qwen2-0.5B-Instruct-finetune
      model_type: QWEN2
    resume_from_checkpoint: False
    
    # Fine-tuning arguments
    batch_size: 15
    epochs: 1
    optimizer:
      _component_: torch.optim.AdamW
      fused: False
      lr: 2e-5
    
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    optimizer_in_bwd: False
    
    max_steps_per_epoch: null
    gradient_accumulation_steps: 8
    compile: False
    
    # Training environment
    device: npu
    
    # Memory management
    enable_activation_checkpointing: True
    
    # Reduced precision
    dtype: bf16
    
    # Logging
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: ${output_dir}
    output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
    log_every_n_steps: 1
    log_peak_memory_stats: False
  • Logs:

    NFO:torchtune.utils._logging:Running FullFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 15
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct
      checkpoint_files:
      - model.safetensors
      model_type: QWEN2
      output_dir: /home/lcg/tmp/torchtune/Qwen2-0.5B-Instruct-finetune
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    epochs: 1
    gradient_accumulation_steps: 8
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /tmp/Qwen2-0.5B-Instruct-finetune
    model:
      _component_: torchtune.models.qwen2.qwen2_0_5b
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 2.0e-05
    optimizer_in_bwd: false
    output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
    resume_from_checkpoint: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.qwen2.qwen2_tokenizer
      max_seq_len: null
      merges_file: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/merges.txt
      path: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/vocab.json
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 3364767838. Local seed is seed + rank = 3364767838 + 0
    Writing logs to /tmp/Qwen2-0.5B-Instruct-finetune/log_1729914193.txt
    /home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
      if self.device.type != 'cpu':
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 1.55 GiB
            NPU peak memory reserved: 1.61 GiB
            NPU peak memory active: 1.55 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer is initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:No learning rate scheduler configured. Using constant learning rate.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                                                    | 0/431 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
    1|431|Loss: 1.136042833328247: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 431/431 [31:03<00:00,  4.45s/it]INFO:torchtune.utils._logging:Model checkpoint of size 0.99 GB saved to /home/lcg/tmp/torchtune/Qwen2-0.5B-Instruct-finetune/hf_model_0001_0.pt
    INFO:torchtune.utils._logging:Saving final epoch checkpoint.
    INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
    1|431|Loss: 1.136042833328247: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 431/431 [31:07<00:00,  4.33s/it]

@noemotiovon
Copy link
Author

noemotiovon commented Oct 26, 2024

Hi @noemotiovon thanks for the PR! And apologies for the delay in getting to the review here. A couple other questions I have that don't really fit neatly anywhere inline:

  1. Do we expect compile to work? If so, we should test that. If not, we could raise an error
  2. Do we expect quant-related APIs (e.g. QLoRA or QAT) from torchao to work? Same as point 1: if so we should test or possibly raise an error
  3. PyTorch has now released 2.5 as stable. In general we do not claim to support anything but the latest stable release of PyTorch -- do you know the contract on torch_npu releases here?

Hi @ebsmothers, Thank you very much for reviewing my code! ☺️ I’ve made the suggested changes, and the goal of this PR is to accomplish the device-independent modifications for torchtune, using NPU as an example. This will involve adapting all recipes to ultimately make torchtune device-independent, with this PR specifically covering full_finetune_single_device and lora_finetune_single_device. Regarding the third point, torch-npu will release the 2.5.0 RC version on November 7, and I’ll be optimizing the code based on PyTorch’s new features as well! Hope you have a fantastic day ahead!

@noemotiovon
Copy link
Author

Hi @ebsmothers, could you please take a moment to review the code ☺️ ? This update currently supports full_finetune_single_device and lora_finetune_single_device on NPU, and we’ll be adding support for additional recipes in the future. I really appreciate your help—thank you!
Best regards

@ebsmothers
Copy link
Contributor

Hi @noemotiovon sorry for the delay! I will take a look tomorrow if that's alright. Until then I'll tag @RdoubleA and @joecummings in case either of them gets a minute to take a look

@noemotiovon
Copy link
Author

Hi @ebsmothers, when you have a moment, could you take a quick look at the recent changes I made? Your feedback would be greatly appreciated. Thank you!

@@ -430,7 +435,7 @@ def _setup_model(

log.info(f"Model is initialized with precision {self._dtype}.")

if self._device.type == "cuda":
if self._device.type in DeviceSupport.get_cuda_like_device_types():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's just me but I find this to be a bit confusing.. can we just scrap this method and use if self._device.type != "cpu" in these places instead? I know it may not be as general but I think this is extra indirection that we don't really need

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you’re right! This encapsulation seems redundant at the moment, as cuda-like devices currently appear to be simply non-CPU devices.

@@ -45,11 +45,11 @@ def set_activation_checkpointing(

def cleanup_before_training() -> None:
"""
Call gc collect, empty CUDA cache, and reset peak memory stats.
Call gc collect, empty CUDA-like cache, and reset peak memory stats.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think "device" should be sufficient here

Suggested change
Call gc collect, empty CUDA-like cache, and reset peak memory stats.
Call gc collect, empty device cache, and reset peak memory stats.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion! I’ll make the descriptions more appropriate.

@@ -50,6 +52,7 @@ def verify_bf16_support() -> bool:
- CUDA compute capability >= 8
- NCCL is available and version >= 2.10
- MPS is available and torch was built with MPS
- NPU is available and supports bf16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is a bit redundant. Do we know the exact requirements for bf16 support on NPUs?

return False


logger = get_logger("DEBUG")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put this at the top of the file (i.e. just after imports but before function definitions)? That's where I'd expect to see it

Comment on lines 63 to 64
device_type = get_device_support().device_type
device_name = get_device_support().device_name
Copy link
Contributor

@ebsmothers ebsmothers Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can just call get_device_support() once

# Ensure device index matches assigned index when distributed training
if device.index != local_rank:
raise RuntimeError(
f"You can't specify a device index when using distributed training. \
Device specified is {device} but was assigned cuda:{local_rank}"
Device specified is {device} but was assigned cuda device:{local_rank}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this log message was not super clear before but this is also unclear. Maybe something like

Suggested change
Device specified is {device} but was assigned cuda device:{local_rank}"
Device specified is {device} but local rank is {local_rank}"

(assuming that NPU devices also contain rank in their string representation?)

Btw a higher-level question here: it was mentioned in a previous comment that there were issues running some of the distributed training scripts on NPU. Did that all get sorted out?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, NPU devices also contain.
There are still some issues with adapting NPU for distributed scripts. I’m working on it and will include the updates in a future PR. Thanks for the reminder!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case we probably don't even need to compare device index against device count for NPU, right? (At least not until those changes land.) Though I think it's fine to leave this as is since handling NPU separately here may be messier

1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu").
2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU").
3. `communication_backend` (str): Specifies the backend used for communication on this device (e.g., "gloo", "nccl", "hccl").
4. `cuda_like` (bool): Indicates whether the device is CUDA-like or not.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in another comment but I'm not sure we 100% need this field. At the very least I find its naming a bit unclear

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, I’ll remove this seemingly redundant attribute!

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @noemotiovon for your patience! I left a handful more comments, please let me know if anything is unclear

@noemotiovon
Copy link
Author

Thanks @noemotiovon for your patience! I left a handful more comments, please let me know if anything is unclear

Thank you for your review! Your feedback is very clear, and I will make the necessary code changes as soon as possible based on your suggestions. ☺️

@noemotiovon
Copy link
Author

Hi @ebsmothers, I’ve made the code changes based on your suggestions; could you please review it again? ☺️

Additionally:

  1. For NPU devices, it currently checks for bf16 support based only on the device model, encapsulated in the torch.npu.is_bf16_supported() method.
  2. Support for distributed functionality is still being debugged and will be gradually integrated into another PR.

Best regards

return DeviceSupport.from_type(device_type)


def get_torch_device() -> any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think get_torch_device is not descriptive enough and is too similar to our existing get_device method. Given the name I would expect this to return a torch.device, but it's really returning a module/namespace, right? In that case maybe we could call it get_torch_device_namespace or something?

"""Return the corresponding torch attribute based on the device type string.

Returns:
module: The corresponding torch module, or torch.cuda if not found.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kinda related to my above comment (specifically it's why I suggested namespace in the function name instead of module).. I think module is a pretty overloaded term in PyTorch, and when I see this I think of nn.Module. Even though you're using it correctly, maybe we can say something like this instead to mitigate any potential confusion?

Suggested change
module: The corresponding torch module, or torch.cuda if not found.
module: The corresponding torch device namespace, or torch.cuda if not found.

return getattr(torch, device_type)
except AttributeError:
logger.warning(
f"Device Module '{device_type}' not found in torch, try to load torch.cuda."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment here

Suggested change
f"Device Module '{device_type}' not found in torch, try to load torch.cuda."
f"Device namespace '{device_type}' not found in torch, try to load torch.cuda."

# Ensure device index matches assigned index when distributed training
if device.index != local_rank:
raise RuntimeError(
f"You can't specify a device index when using distributed training. \
Device specified is {device} but was assigned cuda:{local_rank}"
Device specified is {device} but was assigned cuda device:{local_rank}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case we probably don't even need to compare device index against device count for NPU, right? (At least not until those changes land.) Though I think it's fine to leave this as is since handling NPU separately here may be messier

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @noemotiovon for the updates! I left a couple more comments but I think this is pretty close now. It looks like a unit test is failing in CI though, can you take a look? Happy to provide any debugging pointers if you need

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants