From e14e39f68e1d5fb0cd39b7421ca15b3e0bd73f6f Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Mon, 21 Oct 2024 18:05:59 -0500
Subject: [PATCH] Fix step iteration bug in finetuning scripts (#1794)

---
 litgpt/finetune/adapter.py    | 5 ++++-
 litgpt/finetune/adapter_v2.py | 5 ++++-
 litgpt/finetune/full.py       | 4 +++-
 litgpt/finetune/lora.py       | 4 +++-
 4 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index 4f938e54cb..4ff66088dd 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -268,10 +268,13 @@ def fit(
     total_lengths = 0
     total_t0 = time.perf_counter()
 
-    while step_count < max_steps and train_iterator.epoch < train.epochs:
+    while step_count < max_steps:
         iter_num += 1
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
+        if train_iterator.epoch >= train.epochs:
+            break
+
         input_ids, targets = batch["input_ids"], batch["labels"]
 
         is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index d5451d5987..c8e3415ead 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -268,10 +268,13 @@ def fit(
     total_lengths = 0
     total_t0 = time.perf_counter()
 
-    while step_count < max_steps and train_iterator.epoch < train.epochs:
+    while step_count < max_steps:
         iter_num += 1
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
+        if train_iterator.epoch >= train.epochs:
+            break
+
         input_ids, targets = batch["input_ids"], batch["labels"]
 
         is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index 8d864bcec8..388675fe57 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -244,10 +244,12 @@ def fit(
     )
     fabric.barrier()
 
-    while state["step_count"] < max_steps and train_iterator.epoch < train.epochs:
+    while state["step_count"] < max_steps:
         state["iter_num"] += 1
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
+        if train_iterator.epoch >= train.epochs:
+            break
         input_ids, targets = batch["input_ids"], batch["labels"]
 
         is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index 566eba6730..418a192344 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -299,10 +299,12 @@ def fit(
     total_lengths = 0
     total_t0 = time.perf_counter()
 
-    while step_count < max_steps and train_iterator.epoch < train.epochs:
+    while step_count < max_steps:
         iter_num += 1
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
+        if train_iterator.epoch >= train.epochs:
+            break
         input_ids, targets = batch["input_ids"], batch["labels"]
 
         is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
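
For context on why every hunk moves the epoch check below the batch fetch: the pattern can be shown in isolation. The sketch below is a minimal, self-contained stand-in, not the litgpt implementation; it assumes a CycleIterator-style wrapper whose epoch attribute is only incremented inside next() when the wrapped dataloader is exhausted. Because the counter advances mid-iteration, testing it in the while header (the old code) can let the loop train on a batch from an epoch beyond the configured budget; checking right after next(train_iterator) discards that batch instead.

# Hypothetical, simplified stand-in for illustration only; names and loop body are assumptions.
class CycleIterator:
    """Wraps an iterable and restarts it when exhausted, counting epochs."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.epoch = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # The wrap-around is only observable here, i.e. during a next() call.
            self.epoch += 1
            self._iterator = iter(self.iterable)
            return next(self._iterator)


def train_loop(batches, max_steps, max_epochs):
    train_iterator = CycleIterator(batches)
    step_count = 0
    while step_count < max_steps:
        batch = next(train_iterator)
        # Check the epoch *after* fetching: the counter only advances inside next(),
        # so a batch belonging to an out-of-budget epoch is dropped, not trained on.
        if train_iterator.epoch >= max_epochs:
            break
        step_count += 1  # gradient accumulation omitted in this sketch
        print(f"step {step_count}: batch {batch}")


# With 3 batches, 2 epochs, and a generous step budget, this runs exactly 6 steps.
train_loop(batches=[1, 2, 3], max_steps=10, max_epochs=2)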