Add epochs to levanter #768

Open · wants to merge 22 commits into base: main
Changes from 14 commits
39 changes: 39 additions & 0 deletions config/llama_7b_tulu.yaml
@@ -0,0 +1,39 @@
data:
train_urls:
- "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz"
- "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz"
- "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz"
cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/"
tokenizer: "allenai/OLMo-1B"
model: # 7B class model
type: llama
seq_len: 4096
hidden_dim: 4096
intermediate_dim: 11008
num_layers: 32
num_heads: 32
num_kv_heads: 32
use_flash_attention: True
flash_attention_block_size: 1024
use_bias: false
use_layer_norm_weight: false
trainer:
tracker:
type: wandb
project: "marin"
tags: ["dolma", "olmo", "llama"]

mp: p=f32,c=bfloat16
train_batch_size: 256
num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
steps_per_eval: 1000
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
batch_axis: "batch"
optimizer:
learning_rate: 4E-4
weight_decay: 0.1
min_lr_ratio: 0.1
warmup: 5000

epoch: False
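As a sanity check on what the epoch callback will report for this config, the arithmetic below (plain Python, not levanter code) computes tokens per optimizer step from train_batch_size and seq_len; the dataset size is a hypothetical placeholder, since the Tulu mixture's token count isn't given here.

seq_len, train_batch_size = 4096, 256
tokens_per_step = train_batch_size * seq_len            # 1,048,576 tokens per optimizer step

dataset_tokens = 500_000_000                            # hypothetical total for the mixture
steps_per_epoch = dataset_tokens / tokens_per_step      # ~477 steps per pass over the data
epoch_at_step_1000 = 1000 * tokens_per_step / dataset_tokens  # ~2.1 epochs
print(tokens_per_step, round(steps_per_epoch), round(epoch_at_step_1000, 2))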
4 changes: 3 additions & 1 deletion examples/alpaca/alpaca.py
@@ -162,11 +162,13 @@ def _prepare_example(ex: dict) -> LmExample:
# mask out padding and anything before the start of the target
Pos = input_ids.resolve_axis("position")
if config.mask_inputs:
loss_mask = hax.arange(Pos) >= ex["source_lens"]
loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1 # should be minus 1?

# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
# mask the final position: its rolled target wraps around to the start, so there is no real next token to predict
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
else:
loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
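To see why the mask starts at source_lens - 1, the numpy mock-up below (not haliax; the token values are made up) walks through the shift: after rolling the inputs left by one, position i is scored against token i + 1, so the first position whose target is a completion token is source_lens - 1, and the last position is masked because its rolled-around target isn't a real next token.

import numpy as np

input_ids = np.array([10, 11, 12, 20, 21, 22, 0, 0])   # 3 prompt tokens, 3 target tokens, 2 pads
source_lens, pad_token_id = 3, 0

targets = np.roll(input_ids, -1)                        # position i now predicts token i + 1
loss_mask = np.arange(len(input_ids)) >= source_lens - 1
loss_mask &= targets != pad_token_id                    # never learn to predict padding
loss_mask[-1] = False                                   # final position has no real target
print(np.flatnonzero(loss_mask))                        # [2 3 4]: exactly the completion tokens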
51 changes: 50 additions & 1 deletion src/levanter/callbacks.py
@@ -8,6 +8,7 @@
import threading
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Callable, Optional

@@ -17,7 +18,7 @@
from tqdm_loggable.auto import tqdm

import levanter.tracker
from levanter.data import DataLoader
from levanter.data import DataLoader, AsyncDataset
from levanter.logging import save_xla_dumps_to_wandb
from levanter.tracker.helpers import log_optimizer_hyperparams
from levanter.tracker.wandb import WandbConfig
@@ -30,6 +31,54 @@
logger = pylogging.getLogger(__name__)


def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size):
total_tokens = None

def log_epoch(step_info: StepInfo):
nonlocal total_tokens
if total_tokens is None:
if not total_tokens_future.done():
if step_info.step % 1000 == 0:
logger.info("Dataset not finished. Can't compute epochs.")
return # We don't have the total tokens yet, so we can't calculate epoch
total_tokens = total_tokens_future.result()

# Compute the total number of tokens processed so far from the step count
processed_tokens = tokens_per_example * batch_size * step_info.step

current_epoch = processed_tokens / total_tokens
levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step)

return log_epoch


def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int):
if not ds.is_finite():
raise ValueError("Epochs don't make sense with an infinite dataset.")

def log_length():
# If ds.async_len() is the only option, run it in an event loop inside the thread
import asyncio

async def compute_length():
length = await ds.async_len()
return length

# Run the async function synchronously in this thread
length = asyncio.run(compute_length())
total_tokens = length * seq_length
levanter.tracker.log_summary({"dataset/total_tokens": total_tokens})
return total_tokens

# Create a ThreadPoolExecutor with a single worker thread
executor = ThreadPoolExecutor(max_workers=1)
# Submit the log_length function to be executed in a separate thread
future = executor.submit(log_length)
return future


def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
total_loss = 0.0
total_load_time = 0.0
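As a quick illustration of how the two callbacks fit together, here is a self-contained sketch of the arithmetic log_epoch_progress performs once the total-token future resolves; the Future is filled in by hand instead of by get_total_dataset_tokens, and the dataset size is invented.

from concurrent.futures import Future

total_tokens_future = Future()
total_tokens_future.set_result(1_000_000_000)   # stand-in for get_total_dataset_tokens(...)

tokens_per_example, batch_size = 4096, 256      # Pos.size and trainer.config.train_batch_size

def epoch_at_step(step: int) -> float:
    # same formula the hook logs as train/current_epoch
    processed_tokens = tokens_per_example * batch_size * step
    return processed_tokens / total_tokens_future.result()

print(round(epoch_at_step(1_000), 2))           # ~1.05 epochs after 1,000 steps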
91 changes: 90 additions & 1 deletion src/levanter/data/text.py
@@ -64,6 +64,84 @@
DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index


class EpochDataset(AsyncDataset[T_co]):
"""
A dataset that wraps another dataset, providing infinite epochs by recycling indices.
If `max_epochs` is specified, it limits the number of cycles before raising StopIteration.

:param dataset: The dataset to wrap.
:param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
"""
def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
if not dataset.is_finite():
raise ValueError("Cannot apply epoching to an infinite dataset.")
self.dataset = dataset
self.max_epochs = max_epochs

async def async_len(self) -> int:
if self.max_epochs is None:
raise ValueError("Cannot determine length of an infinite dataset without max_epochs.")
# Return the total number of samples: max_epochs * length of the dataset
return self.max_epochs * await self.dataset.async_len()

async def final_length_is_known(self) -> bool:
return await self.dataset.final_length_is_known()

def is_finite(self) -> bool:
# EpochDataset can be finite if max_epochs is set.
return self.max_epochs is not None

async def current_len(self) -> Optional[int]:
# If max_epochs is None, the dataset is effectively infinite.
if self.max_epochs is None:
return None

# If the final length of the dataset is not known, return the current length of the underlying dataset.
if not await self.dataset.final_length_is_known():
return await self.dataset.current_len()

# If the final length is known, return the max_epochs * async_len of the dataset.
return self.max_epochs * await self.dataset.async_len()

async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
# Ask the underlying dataset for its length once it can cover the batch; a finished
# dataset returns its true (possibly smaller) length, which is what we wrap around
max_index = max(indices)
ds_len = await self.dataset.wait_until_len_at_least(max_index + 1)

# Determine the epoch based on the largest index
epoch = max_index // ds_len

# If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs
if self.max_epochs is not None and epoch >= self.max_epochs:
raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}")

# Wrap the indices within the bounds of the dataset length
wrapped_indices = [idx % ds_len for idx in indices]

# Delegate to the underlying dataset's get_batch
return await self.dataset.get_batch(wrapped_indices)

async def wait_until_len_at_least(self, length: int) -> int:
"""
Returns the length of the dataset once it is at least `length` or if the dataset has a known (finished) length.
If the underlying dataset's final length is smaller than `length`, the total length across all `max_epochs` epochs is returned instead.
"""
# Wait until the underlying dataset's length is at least `length`
if not self.is_finite():
return length

if await self.dataset.final_length_is_known():
base_length = await self.dataset.async_len()
else:
base_length = await self.dataset.wait_until_len_at_least(length)

if base_length < length:
# hit epoch boundary
assert self.max_epochs is not None
return self.max_epochs * base_length

return base_length

class TokenSeqDataset(AsyncDataset[np.ndarray]):
"""
A dataset that yields sequences of tokens of fixed length from an underlying TreeCache.
@@ -648,9 +726,20 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
cache_dir: Optional[str] = "cache/"

def train_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
self,
seq_len: int,
monitors: Union[bool, List[MetricsMonitor]] = True,
*,
key: Optional[PRNGKeyArray] = None,
epochs: bool = False,
) -> AsyncDataset[np.ndarray]:

ds = self.token_seq_dataset("train", seq_len, monitors)
if epochs:
logger.info("Wrapping dataset in epoch dataset")
ds = EpochDataset(ds)
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved

# add epoch flag here.
if ds is None:
raise ValueError("No training set!")

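The core of EpochDataset.get_batch is a modulo over the base dataset's length plus a bound on the epoch count. The asyncio toy below (not levanter's AsyncDataset API; the class name, items, and indices are made up) reproduces just that arithmetic.

import asyncio

class ToyDataset:
    """Stand-in for a finite, already-built async dataset with 5 items."""

    def __init__(self, items):
        self.items = list(items)

    async def wait_until_len_at_least(self, n: int) -> int:
        return len(self.items)  # already finished, so this is simply the true length

    async def get_batch(self, indices):
        return [self.items[i] for i in indices]

async def epoch_get_batch(ds, indices, max_epochs=None):
    max_index = max(indices)
    ds_len = await ds.wait_until_len_at_least(max_index + 1)
    epoch = max_index // ds_len              # which pass over the data the largest index falls in
    if max_epochs is not None and epoch >= max_epochs:
        raise RuntimeError(f"Reached maximum number of epochs: {max_epochs}")
    return await ds.get_batch([i % ds_len for i in indices])  # recycle indices past the end

print(asyncio.run(epoch_get_batch(ToyDataset("abcde"), [3, 4, 5, 6])))
# ['d', 'e', 'a', 'b'] -- indices 5 and 6 wrap into the second epoch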
16 changes: 15 additions & 1 deletion src/levanter/main/train_lm.py
@@ -54,6 +54,7 @@ class TrainLmConfig:
data_seed: Optional[int] = None # if provided, will override the data seed from the trainer
initialize_from_checkpoint_path: Optional[str] = None
# if provided, will initialize from this checkpoint, used for llama style data mixture
epoch: bool | int = False


def main(config: TrainLmConfig):
@@ -117,8 +118,20 @@ def main(config: TrainLmConfig):

# TODO: fix this
tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
# config.data.train_set(Pos.size, key=data_key) is the TokenSeqDataset (possibly wrapped in an EpochDataset)

train_dataset = CausalLmDataset(
config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id
config.data.train_set(Pos.size, key=data_key, epochs=config.epoch),
Pos,
KeyPos,
ignore_index=config.data.ignore_token_id,
)


# add epoch logging
total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
trainer.add_hook(
callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1
)

# to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
Expand Down Expand Up @@ -236,6 +249,7 @@ def compute_log_probs(model, example):

## OK, actually run training!
trainer.train(state, train_loader)

# checkpointer.on_step(last_step, force=True)


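The config field is typed bool | int, but as of this commit only the boolean form is consumed (train_set takes epochs: bool and constructs EpochDataset without max_epochs). The helper below is not part of the PR; it is one hypothetical way the integer form could eventually be mapped onto EpochDataset's max_epochs argument.

from typing import Optional

def epoch_field_to_max_epochs(epoch: bool | int) -> Optional[int]:
    """Hypothetical mapping: True -> cycle forever, False -> no epoching, int -> epoch budget."""
    if epoch is True:
        return None            # EpochDataset(ds) with no limit
    if epoch is False:
        raise ValueError("epoching is disabled; don't wrap the dataset")
    return int(epoch)          # EpochDataset(ds, max_epochs=epoch)

assert epoch_field_to_max_epochs(True) is None
assert epoch_field_to_max_epochs(3) == 3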
1 change: 0 additions & 1 deletion src/levanter/trainer.py
@@ -376,7 +376,6 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi
while int(state.step) < self.num_train_steps:
with capture_time() as loading_time:
example = next(iter_data)

info = self.train_step(state, example)
state = info.state
