From 29e7250c8f481464d0cb4a931c9547dbed8c76bd Mon Sep 17 00:00:00 2001 From: Sze Wai Yuen Date: Fri, 5 Jul 2024 21:33:14 +0000 Subject: [PATCH] blog: LoRA deepdive Summary: Test Plan: --- blog/llm-finetuning-4/.detignore | 2 + blog/llm-finetuning-4/.gitignore | 5 + blog/llm-finetuning-4/README.md | 71 +++++++ blog/llm-finetuning-4/chat_format.py | 67 ++++++ blog/llm-finetuning-4/dataset_utils.py | 69 +++++++ blog/llm-finetuning-4/deepspeed.yaml | 36 ++++ .../ds_configs/ds_config_stage_1.json | 48 +++++ .../ds_configs/ds_config_stage_2.json | 48 +++++ .../ds_config_stage_2_cpu_offload.json | 52 +++++ .../ds_configs/ds_config_stage_3.json | 47 +++++ blog/llm-finetuning-4/finetune.py | 195 ++++++++++++++++++ blog/llm-finetuning-4/inference.py | 61 ++++++ blog/llm-finetuning-4/lora.yaml | 50 +++++ blog/llm-finetuning-4/requirements.txt | 8 + blog/llm-finetuning-4/startup-hook.sh | 3 + blog/llm-finetuning-4/validate_tokenizer.py | 126 +++++++++++ 16 files changed, 888 insertions(+) create mode 100644 blog/llm-finetuning-4/.detignore create mode 100644 blog/llm-finetuning-4/.gitignore create mode 100644 blog/llm-finetuning-4/README.md create mode 100644 blog/llm-finetuning-4/chat_format.py create mode 100644 blog/llm-finetuning-4/dataset_utils.py create mode 100644 blog/llm-finetuning-4/deepspeed.yaml create mode 100644 blog/llm-finetuning-4/ds_configs/ds_config_stage_1.json create mode 100644 blog/llm-finetuning-4/ds_configs/ds_config_stage_2.json create mode 100644 blog/llm-finetuning-4/ds_configs/ds_config_stage_2_cpu_offload.json create mode 100644 blog/llm-finetuning-4/ds_configs/ds_config_stage_3.json create mode 100644 blog/llm-finetuning-4/finetune.py create mode 100644 blog/llm-finetuning-4/inference.py create mode 100644 blog/llm-finetuning-4/lora.yaml create mode 100644 blog/llm-finetuning-4/requirements.txt create mode 100644 blog/llm-finetuning-4/startup-hook.sh create mode 100644 blog/llm-finetuning-4/validate_tokenizer.py diff --git a/blog/llm-finetuning-4/.detignore b/blog/llm-finetuning-4/.detignore new file mode 100644 index 0000000..5e741f0 --- /dev/null +++ b/blog/llm-finetuning-4/.detignore @@ -0,0 +1,2 @@ +text-to-sql* +checkpoints \ No newline at end of file diff --git a/blog/llm-finetuning-4/.gitignore b/blog/llm-finetuning-4/.gitignore new file mode 100644 index 0000000..d3f89f5 --- /dev/null +++ b/blog/llm-finetuning-4/.gitignore @@ -0,0 +1,5 @@ +__pycache__ +.DS_STORE +text-to-sql* +checkpoints +*.png \ No newline at end of file diff --git a/blog/llm-finetuning-4/README.md b/blog/llm-finetuning-4/README.md new file mode 100644 index 0000000..27a1c3c --- /dev/null +++ b/blog/llm-finetuning-4/README.md @@ -0,0 +1,71 @@ +# Finetuning Mistral-7B using LoRA and DeepSpeed + +In this demo, we finetune [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) using [LoRA](https://arxiv.org/abs/2106.09685) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). We ran LoRA on two 80 GB A100 GPUs, and DeepSpeed on two, four, and eight 80 GB A100 GPUs. + +To get started, first install Determined on your local machine: +```bash +pip install determined +``` + +Then finetune with LoRA: +```bash +det e create lora.yaml . +``` + +Or finetune with DeepSpeed: +```bash +det e create deepspeed.yaml . +``` + +You can view the actual training code in `finetune.py`. + + + + +## Configuration + +Change configuration options in `lora.yaml` or `deepspeed.yaml`. Some important options are: +- `slots_per_trial`: the number of GPUs to use. 
+- `dataset_subset`: the difficulty subset to train on.
+- `per_device_train_batch_size`: the batch size per GPU.
+
+The results in [our blog post](https://www.determined.ai/blog/llm-finetuning-2) were obtained using `per_device_train_batch_size: 1` and `per_device_eval_batch_size: 4`.
+
+
+DeepSpeed configuration files are in the `ds_configs` folder.
+
+## Testing
+
+Test your model's generation capabilities:
+
+```bash
+python inference.py --exp_id <exp_id> --dataset_subset <dataset_subset>
+```
+
+Where
+- `<exp_id>` is the id of your finetuning experiment in the Determined UI.
+- `<dataset_subset>` is one of "easy", "medium", or "hard".
+
+If you're testing a LoRA model, then add `--lora` to the above command.
+
+To use CPU instead of GPU, add `--device cpu`.
+
+To test the pretrained model (not finetuned), leave out `--exp_id`. For example:
+
+```bash
+python inference.py --dataset_subset easy
+```
+
+## Validating the tokenizer
+
+Plot the distribution of dataset sample lengths, and see how many samples will be truncated by the tokenizer:
+
+```bash
+python validate_tokenizer.py
+```
+
+
+## Contributors
+
+- [Kevin Musgrave](https://github.com/KevinMusgrave)
+- [Agnieszka Ciborowska](https://github.com/aciborowska)
diff --git a/blog/llm-finetuning-4/chat_format.py b/blog/llm-finetuning-4/chat_format.py
new file mode 100644
index 0000000..4ba4baf
--- /dev/null
+++ b/blog/llm-finetuning-4/chat_format.py
@@ -0,0 +1,67 @@
+CHAT_ML_TEMPLATE = """
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+{{'<|im_start|>user\n' + message['content'].strip() + '<|im_end|>' }}
+{% elif message['role'] == 'system' %}
+{{'<|im_start|>system\n' + message['content'].strip() + '<|im_end|>' }}
+{% elif message['role'] == 'assistant' %}
+{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}
+{% endif %}
+{% endfor %}
+"""
+
+
+CHAT_ML_EOS_TOKEN = "<|im_end|>"
+
+
+def get_chat_format(element, model_name, with_assistant_response=True):
+    system_prompt = (
+        "You are a helpful programmer assistant that excels at SQL. "
+        "When prompted with a task and a definition of an SQL table, you "
+        "respond with a SQL query to retrieve information from the table. "
+        "Don't explain your reasoning, only provide the SQL query."
+ ) + + user_prompt = "Task: {instruction}\nSQL table: {input}\nSQL query: " + + if model_name == "mistralai/Mistral-7B-Instruct-v0.2": + user_prompt = f"{system_prompt}\n{user_prompt}" + output = [ + {"role": "user", "content": user_prompt.format_map(element)}, + ] + else: + output = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt.format_map(element)}, + ] + + if with_assistant_response: + output.append({"role": "assistant", "content": element["response"]}) + + return output + + +def set_special_tokens(tokenizer, model_name): + if model_name == "TinyLlama/TinyLlama-1.1B-Chat-v0.4": + tokenizer.chat_template = CHAT_ML_TEMPLATE + tokenizer.eos_token = CHAT_ML_EOS_TOKEN + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + +def get_assistant_prompt(model_name): + if model_name == "TinyLlama/TinyLlama-1.1B-Chat-v0.4": + return "<|im_start|>assistant\n" + else: + return "[/INST]" + + +def get_response_template_ids(tokenizer, model_name): + return tokenizer.encode(get_assistant_prompt(model_name), add_special_tokens=False) + + +def maybe_add_generation_prompt(x, model_name): + if model_name == "TinyLlama/TinyLlama-1.1B-Chat-v0.4": + return x + get_assistant_prompt(model_name) + else: + return x diff --git a/blog/llm-finetuning-4/dataset_utils.py b/blog/llm-finetuning-4/dataset_utils.py new file mode 100644 index 0000000..38e49da --- /dev/null +++ b/blog/llm-finetuning-4/dataset_utils.py @@ -0,0 +1,69 @@ +import datasets +import pandas as pd + + +def add_length_column(dataset) -> pd.DataFrame: + df = dataset.to_pandas() + df["total_length"] = 0 + for column_name in ["instruction", "input", "response"]: + num_words = df[column_name].astype(str).str.split().apply(len) + df["total_length"] += num_words + + return df + + +def filter_by_total_length(df, difficulty, number_of_samples): + if difficulty == "easy": + return df[df["total_length"].between(10, 100)].iloc[:number_of_samples] + elif difficulty == "medium": + return df[df["total_length"].between(101, 200)].iloc[:number_of_samples] + elif difficulty == "hard": + return df[df["total_length"].between(201, 800)].iloc[:number_of_samples] + + +def get_dataset_subset_name(difficulty: str) -> str: + return f"text-to-sql-v1-{difficulty}" + + +def create_and_save_datasets( + df, difficulty, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 +): + seed = 123 + # remove total_length column because we don't need it anymore + df = df.drop(columns=["total_length"]) + dataset = datasets.Dataset.from_pandas(df, preserve_index=False) + + # split into training and "the rest" + train_valtest = dataset.train_test_split(train_size=train_ratio, seed=seed) + + # split "the rest" into validation and testing + val_test = train_valtest["test"].train_test_split( + test_size=test_ratio / (test_ratio + val_ratio), seed=seed + ) + + dataset = datasets.DatasetDict( + { + "train": train_valtest["train"], + "valid": val_test["train"], + "test": val_test["test"], + } + ) + dataset_name = get_dataset_subset_name(difficulty) + dataset.save_to_disk(dataset_name) + return dataset + + +def load_dataset(difficulty): + return datasets.load_from_disk(get_dataset_subset_name(difficulty)) + + +def load_or_create_dataset(difficulty, num_samples=10000): + try: + return load_dataset(difficulty) + except FileNotFoundError: + dataset = datasets.load_dataset("Clinton/Text-to-sql-v1") + dataset = dataset["train"] + dataset = dataset.remove_columns(["text", "source"]) + df = add_length_column(dataset) + df = 
filter_by_total_length(df, difficulty, num_samples) + return create_and_save_datasets(df, difficulty) diff --git a/blog/llm-finetuning-4/deepspeed.yaml b/blog/llm-finetuning-4/deepspeed.yaml new file mode 100644 index 0000000..dfd031e --- /dev/null +++ b/blog/llm-finetuning-4/deepspeed.yaml @@ -0,0 +1,36 @@ +name: mistral deepspeed easy +debug: false +environment: + environment_variables: + - NCCL_DEBUG=INFO + image: determinedai/environments:cuda-11.8-pytorch-2.0-gpu-95c7a14 +resources: + slots_per_trial: 2 +searcher: + name: single + max_length: + batches: 5000 + metric: eval_accuracy + smaller_is_better: false +hyperparameters: + model: "mistralai/Mistral-7B-Instruct-v0.2" + dataset_subset: "easy" + lora: false + training_args: + output_dir: "/tmp/llm_finetuning" + max_steps: 5000 + per_device_train_batch_size: 2 + per_device_eval_batch_size: 4 + bf16: true + evaluation_strategy: "steps" + eval_steps: 1000 + logging_strategy: "steps" + logging_steps: 100 + save_strategy: "steps" + save_steps: 5000 + learning_rate: 1e-5 + deepspeed: "ds_configs/ds_config_stage_3.json" +entrypoint: >- + python -m determined.launch.deepspeed + python finetune.py +max_restarts: 0 \ No newline at end of file diff --git a/blog/llm-finetuning-4/ds_configs/ds_config_stage_1.json b/blog/llm-finetuning-4/ds_configs/ds_config_stage_1.json new file mode 100644 index 0000000..b1dc8c5 --- /dev/null +++ b/blog/llm-finetuning-4/ds_configs/ds_config_stage_1.json @@ -0,0 +1,48 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "flops_profiler": { + "enabled": true, + "profile_step": 1, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + } +} diff --git a/blog/llm-finetuning-4/ds_configs/ds_config_stage_2.json b/blog/llm-finetuning-4/ds_configs/ds_config_stage_2.json new file mode 100644 index 0000000..9a7bed0 --- /dev/null +++ b/blog/llm-finetuning-4/ds_configs/ds_config_stage_2.json @@ -0,0 +1,48 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "flops_profiler": { + "enabled": true, + 
"profile_step": 1, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + } +} diff --git a/blog/llm-finetuning-4/ds_configs/ds_config_stage_2_cpu_offload.json b/blog/llm-finetuning-4/ds_configs/ds_config_stage_2_cpu_offload.json new file mode 100644 index 0000000..78dd2a2 --- /dev/null +++ b/blog/llm-finetuning-4/ds_configs/ds_config_stage_2_cpu_offload.json @@ -0,0 +1,52 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "flops_profiler": { + "enabled": true, + "profile_step": 1, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + } +} diff --git a/blog/llm-finetuning-4/ds_configs/ds_config_stage_3.json b/blog/llm-finetuning-4/ds_configs/ds_config_stage_3.json new file mode 100644 index 0000000..a21bd3d --- /dev/null +++ b/blog/llm-finetuning-4/ds_configs/ds_config_stage_3.json @@ -0,0 +1,47 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto" +} diff --git a/blog/llm-finetuning-4/finetune.py b/blog/llm-finetuning-4/finetune.py new file mode 100644 index 0000000..3abbb7f --- /dev/null +++ b/blog/llm-finetuning-4/finetune.py @@ -0,0 +1,195 @@ +import logging +import os +import sys + +import datasets +import determined as det +import evaluate +import torch +import transformers +from determined.transformers import DetCallback +from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments +from trl import DataCollatorForCompletionOnlyLM + +from chat_format import get_chat_format, get_response_template_ids, set_special_tokens +from dataset_utils import load_or_create_dataset + +logger = logging.getLogger(__name__) + + +def get_tokenizer(model_name, 
model_commit_hash): + tokenizer = AutoTokenizer.from_pretrained( + model_name, + padding_side="right", + truncation_side="right", + revision=model_commit_hash, + ) + set_special_tokens(tokenizer, model_name) + return tokenizer + + +def get_model_and_tokenizer(model_name, use_lora, hparams, inference=False, device_map="auto", model_commit_hash=None): + if inference: + if use_lora: + model = AutoPeftModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16, device_map=device_map, revision=model_commit_hash + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device_map, + revision=model_commit_hash, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + revision=model_commit_hash, + ) + + if use_lora: + r = hparams["r"] + lora_alpha = r * hparams["lora_alpha_in_r"] + peft_config = LoraConfig( + task_type="CAUSAL_LM", + inference_mode=False, + r=r, + lora_alpha=lora_alpha, + lora_dropout=hparams["lora_dropout"], + ) + + model = get_peft_model(model, peft_config) + + tokenizer = get_tokenizer(model_name, model_commit_hash=model_commit_hash) + return model, tokenizer + + +def get_tokenize_fn(tokenizer): + def fn(formatted): + return tokenizer(formatted, padding=True, truncation=True, max_length=2048) + + return fn + + +def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + # Depending on the model and config, logits may contain extra tensors, + # like past_key_values, but logits always come first + logits = logits[0] + return logits.argmax(dim=-1) + + +def main(training_args, det_callback, hparams): + if "hf_token" in hparams: + import huggingface_hub + + huggingface_hub.login(token=hparams["hf_token"]) + + model_name = hparams["model"] + model_commit_hash = None + if "model_commit_hash" in hparams: + model_commit_hash = hparams["model_commit_hash"] + model, tokenizer = get_model_and_tokenizer(model_name, hparams["lora"], hparams=hparams, model_commit_hash=model_commit_hash) + tokenize_fn = get_tokenize_fn(tokenizer) + + def tokenize(element): + formatted = tokenizer.apply_chat_template( + get_chat_format(element, model_name), tokenize=False + ) + outputs = tokenize_fn(formatted) + return { + "input_ids": outputs["input_ids"], + "attention_mask": outputs["attention_mask"], + } + + dataset = load_or_create_dataset(hparams["dataset_subset"]) + column_names = list(dataset["train"].features) + for k in dataset.keys(): + dataset[k] = dataset[k].map(tokenize, remove_columns=column_names) + + response_template_ids = get_response_template_ids(tokenizer, model_name) + collator = DataCollatorForCompletionOnlyLM( + response_template_ids, tokenizer=tokenizer + ) + + bleu = evaluate.load("bleu") + acc = evaluate.load("accuracy") + + def compute_metrics(eval_preds): + preds, labels = eval_preds + # preds have the same shape as the labels, after the argmax(-1) has been calculated + # by preprocess_logits_for_metrics but we need to shift the labels + labels = labels[:, 1:] + preds = preds[:, :-1] + # -100 is a default value for ignore_index used by DataCollatorForCompletionOnlyLM + mask = labels == -100 + labels[mask] = tokenizer.pad_token_id + preds[mask] = tokenizer.pad_token_id + + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + bleu_score = bleu.compute(predictions=decoded_preds, references=decoded_labels) + accuracy = 
acc.compute(predictions=preds[~mask], references=labels[~mask]) + + return {**bleu_score, **accuracy} + + trainer = Trainer( + args=training_args, + model=model, + tokenizer=tokenizer, + data_collator=collator, + train_dataset=dataset["train"], + eval_dataset=dataset["valid"], + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + compute_metrics=compute_metrics, + ) + + trainer.add_callback(det_callback) + + trainer.train() + + +if __name__ == "__main__": + # Setup logging + logging.basicConfig( + format=det.LOG_FORMAT, handlers=[logging.StreamHandler(sys.stdout)] + ) + log_level = logging.INFO + transformers.utils.logging.set_verbosity_info() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + info = det.get_cluster_info() + hparams = info.trial.hparams + + if "hf_token" in hparams: + import huggingface_hub + + huggingface_hub.login(token=hparams["hf_token"]) + + if hparams["training_args"]["deepspeed"]: + hparams["training_args"]["deepspeed"] = "ds_configs/ds_config_stage_3.json" + + training_args = TrainingArguments(**hparams["training_args"]) + if training_args.deepspeed: + # Set env var for deepspeed distributed context + os.environ["LOCAL_SIZE"] = os.environ["LOCAL_WORLD_SIZE"] + os.environ["CROSS_RANK"] = str(int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"])) + os.environ["CROSS_SIZE"] = str(int(os.environ["WORLD_SIZE"]) // int(os.environ["LOCAL_WORLD_SIZE"])) + os.environ["CHIEF_IP"] = os.environ["DET_CHIEF_IP"] + distributed = det.core.DistributedContext.from_deepspeed() + else: + distributed = det.core.DistributedContext.from_torch_distributed() + + with det.core.init(distributed=distributed) as core_context: + det_callback = DetCallback( + core_context, + training_args, + ) + main(training_args, det_callback, hparams) diff --git a/blog/llm-finetuning-4/inference.py b/blog/llm-finetuning-4/inference.py new file mode 100644 index 0000000..40d6a79 --- /dev/null +++ b/blog/llm-finetuning-4/inference.py @@ -0,0 +1,61 @@ +import argparse +import glob + +from determined.experimental import client + +from chat_format import get_chat_format, maybe_add_generation_prompt +from dataset_utils import load_or_create_dataset +from finetune import get_model_and_tokenizer + + +def main(exp_id, dataset_subset, lora, device): + model_name = "mistralai/Mistral-7B-Instruct-v0.2" + if exp_id is None: + checkpoint_dir = model_name + else: + exp = client.get_experiment(exp_id) + checkpoint = exp.list_checkpoints( + max_results=1, + sort_by=client.CheckpointSortBy.SEARCHER_METRIC, + order_by=client.OrderBy.DESCENDING, + )[0] + checkpoint_dir = checkpoint.download(mode=client.DownloadMode.MASTER) + checkpoint_dir = glob.glob(f"{checkpoint_dir}/checkpoint-*")[0] + + model, tokenizer = get_model_and_tokenizer( + checkpoint_dir, lora, inference=True, device_map=device + ) + + dataset = load_or_create_dataset(dataset_subset)["test"] + element = dataset[0] + formatted = tokenizer.apply_chat_template( + get_chat_format( + {"instruction": element["instruction"], "input": element["input"]}, + model_name, + with_assistant_response=False, + ), + tokenize=False, + ) + formatted = maybe_add_generation_prompt(formatted, model_name) + print(formatted) + + inputs = tokenizer(formatted, return_tensors="pt").to(device) + outputs = model.generate( + **inputs, eos_token_id=tokenizer.eos_token_id, 
max_new_tokens=1000 + ) + input_length = inputs["input_ids"].shape[1] + response = tokenizer.batch_decode( + outputs[:, input_length:], skip_special_tokens=True + ) + print(f"\n\nCorrect response:\n{element['response']}") + print(f"\n\nLLM response:\n{response[0]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--exp_id", type=int, default=None, required=False) + parser.add_argument("--dataset_subset", type=str, default="easy", required=False) + parser.add_argument("--lora", action="store_true") + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + main(args.exp_id, args.dataset_subset, args.lora, args.device) diff --git a/blog/llm-finetuning-4/lora.yaml b/blog/llm-finetuning-4/lora.yaml new file mode 100644 index 0000000..5d6ef2c --- /dev/null +++ b/blog/llm-finetuning-4/lora.yaml @@ -0,0 +1,50 @@ +name: mistral lora easy +debug: false +environment: + environment_variables: + - NCCL_DEBUG=INFO + image: + gpu: determinedai/environments:cuda-11.8-pytorch-2.0-gpu-95c7a14 + cpu: determinedai/environments:py-3.10-pytorch-2.0-cpu-03ae7d7 +resources: + slots_per_trial: 2 +searcher: + name: grid + max_length: + batches: 5000 + metric: eval_accuracy + smaller_is_better: false +hyperparameters: + model: "mistralai/Mistral-7B-Instruct-v0.2" + model_commit_hash: "99259002b41e116d28ccb2d04a9fbe22baed0c7f" + dataset_subset: "easy" + lora: true + # Tunable hyperparameters + r: + type: categorical + vals: [1, 2, 4, 8, 16, 32, 64] + lora_alpha_in_r: + type: categorical + vals: [0.5, 1, 2] + lora_dropout: + type: categorical + vals: [0.1] + # End tunable hyperparameters + training_args: + output_dir: "/tmp/llm_finetuning" + max_steps: 5000 + per_device_train_batch_size: 8 + per_device_eval_batch_size: 4 + fp16: true + evaluation_strategy: "steps" + eval_steps: 1000 + logging_strategy: "steps" + logging_steps: 100 + save_strategy: "steps" + save_steps: 1000 + learning_rate: 1e-5 + deepspeed: true +entrypoint: >- + python -m determined.launch.torch_distributed + python finetune.py +max_restarts: 0 \ No newline at end of file diff --git a/blog/llm-finetuning-4/requirements.txt b/blog/llm-finetuning-4/requirements.txt new file mode 100644 index 0000000..c6cad1b --- /dev/null +++ b/blog/llm-finetuning-4/requirements.txt @@ -0,0 +1,8 @@ +transformers==4.37.2 +datasets==2.17.0 +evaluate==0.4.1 +trl==0.7.10 +scikit-learn==1.4.0 +deepspeed==0.10.2 +peft==0.8.2 +huggingface_hub \ No newline at end of file diff --git a/blog/llm-finetuning-4/startup-hook.sh b/blog/llm-finetuning-4/startup-hook.sh new file mode 100644 index 0000000..50aa896 --- /dev/null +++ b/blog/llm-finetuning-4/startup-hook.sh @@ -0,0 +1,3 @@ +#!/bin/bash +pip install --upgrade pip +pip install -r requirements.txt diff --git a/blog/llm-finetuning-4/validate_tokenizer.py b/blog/llm-finetuning-4/validate_tokenizer.py new file mode 100644 index 0000000..f73f3ff --- /dev/null +++ b/blog/llm-finetuning-4/validate_tokenizer.py @@ -0,0 +1,126 @@ +from collections import defaultdict +from pprint import pprint + +import matplotlib.pyplot as plt +import numpy as np +import torch +import tqdm + +from chat_format import get_assistant_prompt, get_chat_format, get_response_template_ids +from dataset_utils import load_or_create_dataset +from finetune import get_tokenize_fn, get_tokenizer + +model_name = "mistralai/Mistral-7B-Instruct-v0.2" +tokenizer = get_tokenizer(model_name) +tokenize_fn = get_tokenize_fn(tokenizer) +num_missing_response_template = defaultdict(lambda: 
defaultdict(int)) +num_incomplete = defaultdict(lambda: defaultdict(int)) +num_tokens = defaultdict(lambda: defaultdict(list)) +num_tokens_before_response = defaultdict(lambda: defaultdict(list)) +num_tokens_with_padding_and_truncation = defaultdict(lambda: defaultdict(list)) + + +def to_str(ids): + return ",".join([str(i) for i in ids]) + + +def plot_histogram( + data, + bins, + title, + filename_prefix, +): + hist_data, bin_edges = np.histogram(data, bins=bins) + plt.figure() + plt.bar( + (bin_edges[:-1] + bin_edges[1:]) / 2, + hist_data, + width=np.diff(bin_edges), + edgecolor="black", + ) + plt.xlabel("Bin") + plt.ylabel("Frequency") + plt.title(title) + plt.savefig(f"{filename_prefix}.png") + plt.close() + + return bin_edges + + +def get_collate_fn(difficulty, split): + def fn(x): + formatted = [] + before_response_formatted = [] + for e in x: + with_chat_template = tokenizer.apply_chat_template( + get_chat_format(e, model_name), tokenize=False + ) + formatted.append(with_chat_template) + before_response_formatted.append( + with_chat_template.split(get_assistant_prompt(model_name))[0] + ) + untruncated = tokenizer(formatted, padding=False, truncation=False)["input_ids"] + before_response_untruncated = tokenizer( + before_response_formatted, + padding=False, + truncation=False, + )["input_ids"] + element = tokenize_fn(formatted)["input_ids"] + response_template = to_str(get_response_template_ids(tokenizer, model_name)) + for i, e in enumerate(element): + num_tokens[difficulty][split].append(len(untruncated[i])) + num_tokens_before_response[difficulty][split].append( + len(before_response_untruncated[i]) + ) + num_tokens_with_padding_and_truncation[difficulty][split].append(len(e)) + if response_template not in to_str(e): + num_missing_response_template[difficulty][split] += 1 + decoded = tokenizer.decode(e) + if x[i]["response"] not in decoded: + num_incomplete[difficulty][split] += 1 + + return element + + return fn + + +def validate(): + batch_size = 4 + for difficulty in ["easy", "medium", "hard"]: + dataset = load_or_create_dataset(difficulty) + for split in ["train", "valid", "test"]: + print(difficulty, split) + dataloader = torch.utils.data.DataLoader( + dataset[split], + batch_size=batch_size, + collate_fn=get_collate_fn(difficulty, split), + ) + for _ in tqdm.tqdm(dataloader): + pass + + plot_histogram( + np.array(num_tokens[difficulty][split]), + bins=100, + title=f"{difficulty} {split} # Tokens", + filename_prefix=f"{difficulty}_{split}_tokens", + ) + + plot_histogram( + np.array(num_tokens_before_response[difficulty][split]), + bins=100, + title=f"{difficulty} {split} # Tokens Before Response", + filename_prefix=f"{difficulty}_{split}_tokens_before_response", + ) + + plot_histogram( + np.array(num_tokens_with_padding_and_truncation[difficulty][split]), + bins=100, + title=f"{difficulty} {split} # Tokens with Padding & Truncation", + filename_prefix=f"{difficulty}_{split}_tokens_with_pad_trunc_{batch_size}", + ) + + pprint(num_missing_response_template) + pprint(num_incomplete) + + +validate()
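
The LoRA-specific hyperparameters in `lora.yaml` (`r`, `lora_alpha_in_r`, `lora_dropout`) are turned into a PEFT `LoraConfig` inside `finetune.py`, with `lora_alpha` derived as `r * lora_alpha_in_r` rather than set directly. The snippet below is a minimal standalone sketch of that wiring (not part of the patch): the model name and hyperparameter values are taken from `lora.yaml`, and loading the 7B base model in bfloat16 needs roughly 15 GB of memory.

```python
# Minimal sketch of the LoRA setup used in finetune.py, run outside of Determined.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
r = 8                    # one of the grid values searched in lora.yaml
lora_alpha_in_r = 2      # alpha expressed as a multiple of r, as in lora.yaml
lora_alpha = r * lora_alpha_in_r  # mirrors the derivation in finetune.py

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=0.1,
)

# Load the frozen base model, then attach the low-rank adapter layers.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = get_peft_model(model, peft_config)

# Reports trainable (adapter) parameters vs. total parameters; for r=8 this is
# a small fraction of a percent of the 7B base model.
model.print_trainable_parameters()
```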