Add support for QAT + LoRA #1931

Status: Draft · wants to merge 1 commit into main from try-qat-lora
Conversation

@andrewor14 (Contributor) commented Oct 31, 2024

TODO: write this

Helpful code review commands (diff the new QAT + LoRA recipe and configs against their existing LoRA counterparts):

diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py
diff --color recipes/configs/llama2/7B_lora.yaml recipes/configs/llama2/7B_qat_lora.yaml
diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml

Test Plan

Integration tests:

pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py

Manual tests (end-to-end: finetune with QAT + LoRA, quantize the resulting checkpoint to 8da4w, then evaluate the quantized model on wikitext):

export CUDA_VISIBLE_DEVICES=4,5,6,7
export NCCL_SHM_DISABLE=0
LOG_DIR=/home/andrewor/local/logs/tune/qat_lora

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \
    batch_size=8 \
    quantizer.groupsize=32 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics"

tune run quantize --config quantization \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt"] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32
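
The quantize step above applies torchao's 8da4w (int8 dynamic activations, int4 grouped weights) post-training quantization to the finetuned checkpoint. Roughly equivalent standalone Python might look like the sketch below, assuming torchtune's re-exported torchao quantizer and eliding checkpoint loading and saving, which the tune run wrapper handles:

import torch
from torchtune.models.llama3 import llama3_8b
from torchtune.training.quantization import Int8DynActInt4WeightQuantizer

# Sketch only: build the model and load the finetuned weights (elided),
# then quantize with the same groupsize used during QAT finetuning.
model = llama3_8b()
quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
model = quantizer.quantize(model)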

tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    tasks=[wikitext] \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

Results:

# Baseline (LoRA only, no QAT)

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6676|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5884|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |11.8741|±  |   N/A|

# LoRA + QAT (new recipe)

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6623|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5826|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |11.6457|±  |   N/A|
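
All three wikitext metrics improve slightly with QAT + LoRA over the LoRA-only baseline (e.g. word perplexity 11.6457 vs 11.8741), suggesting that fake quantization during finetuning helps the model tolerate the final 8da4w quantization.

For intuition, the core idea is to fake-quantize the frozen base weights in the forward pass while training only the LoRA adapters. Below is a minimal, self-contained PyTorch sketch of that idea, not the recipe's actual implementation; the symmetric per-group int4 scheme and group size mirror the 8da4w settings above, and all names are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quantize_int4_grouped(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Symmetric per-group int4 fake quantization: quantize then immediately
    # dequantize, so weights stay in float but take on quantized values.
    assert w.shape[-1] % group_size == 0
    wg = w.reshape(-1, group_size)
    qmax = 7  # symmetric int4 range is [-8, 7]
    scale = wg.abs().amax(dim=1, keepdim=True).clamp(min=1e-9) / qmax
    wq = torch.clamp(torch.round(wg / scale), -8, qmax) * scale
    return wq.reshape(w.shape)

class QATLoRALinear(nn.Module):
    # Frozen base linear whose weight is fake-quantized on each forward,
    # plus trainable low-rank LoRA adapters (illustrative names).
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0, group_size: int = 32):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # adapters start as a no-op delta
        self.scaling = alpha / rank
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = fake_quantize_int4_grouped(self.base.weight, self.group_size)
        return F.linear(x, w, self.base.bias) + self.scaling * self.lora_b(self.lora_a(x))

Training then proceeds as in a normal LoRA finetune, with only lora_a/lora_b receiving gradients. Note this sketch omits the int8 dynamic per-token activation fake quantization that the "8da4w" scheme also applies.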

@pytorch-bot commented Oct 31, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1931

❌ 5 New Failures, 5 Cancelled Jobs as of commit 1a48a20 with merge base f560cbb.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@andrewor14 marked this pull request as draft on October 31, 2024 00:10
@facebook-github-bot added the CLA Signed label on Oct 31, 2024
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

@andrewor14 (Contributor, Author) commented:

Will delete these files before landing.
@andrewor14 force-pushed the try-qat-lora branch 2 times, most recently from e20e891 to d09c71f on November 1, 2024 19:38
TODO: write this
# TODO: Expose fake quantize configs from torchao so we can get them
# directly from the quantizer. For now, we hardcode the configs for 8da4w.
# E.g. activation_config = quantizer.get_activation_fake_quantize_config()
# E.g. weight_config = quantizer.get_weight_fake_quantize_config()
@andrewor14 (Contributor, Author) commented:

Addressed in pytorch/ao#1214
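
For context, the hardcoded 8da4w configs that this TODO refers to would presumably look something like the sketch below, using torchao's FakeQuantizeConfig; the exact arguments and import path are assumptions based on the 8da4w scheme (dynamic asymmetric per-token int8 activations, symmetric per-group int4 weights) and may vary across torchao versions:

import torch
from torchao.quantization.qat import FakeQuantizeConfig

# Assumed stand-ins for the getters proposed in the TODO above.
# Note: the int4 dtype spelling differs across versions (older torchao
# releases used an internal enum rather than a torch dtype).
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)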
