
Add epochs to levanter #768

Open · ahmeda14960 wants to merge 22 commits into main
Conversation

ahmeda14960 (Contributor):

Adds epoch support behind a boolean flag: when enabled, training keeps epoching over the dataset and the current epoch is tracked throughout training. Should be backwards compatible with existing checkpoints.
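A toy sketch of the index-recycling idea behind the flag (plain Python with made-up numbers, not levanter code):

# Recycle global indices over a finite dataset so sampling can continue past one pass,
# while the epoch number stays recoverable from the global index.
data = list(range(10))            # pretend dataset of 10 sequences
for global_idx in range(25):      # keep drawing well past the end of one pass
    epoch, local_idx = divmod(global_idx, len(data))
    sample = data[local_idx]      # wraps around; epoch goes 0, 1, 2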

ahmeda14960 marked this pull request as ready for review October 16, 2024 23:34
@@ -63,6 +63,57 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

class TokenSeqEpochDataset(AsyncDataset[np.ndarray]):
Member:

let's just make EpochDataset that wraps an arbitrary dataset.

Member:

ChatGPT and I made this:

from typing import Optional, Sequence, TypeVar

# Assumes AsyncDataset is importable from levanter's data module; adjust the path if needed.
from levanter.data import AsyncDataset

T_co = TypeVar("T_co", covariant=True)

class EpochDataset(AsyncDataset[T_co]):
    """
    A dataset that wraps another dataset, providing infinite epochs by recycling indices.
    If `max_epochs` is specified, it limits the number of cycles before raising StopIteration.
    
    :param dataset: The dataset to wrap.
    :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
    """
    def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
        if dataset.is_finite():
            raise ValueError("Cannot apply epoching to a finite dataset.")

        self.dataset = dataset
        self.max_epochs = max_epochs

    async def async_len(self) -> int:
        if self.max_epochs is None:
            raise ValueError("Cannot determine length of an infinite dataset without max_epochs.")
        # Return the total number of samples: max_epochs * length of the dataset
        return self.max_epochs * await self.dataset.async_len()

    async def final_length_is_known(self) -> bool:
        return await self.dataset.final_length_is_known()

    def is_finite(self) -> bool:
        # EpochDataset can be finite if max_epochs is set.
        return self.max_epochs is not None

    async def current_len(self) -> Optional[int]:
        # If max_epochs is None, the dataset is effectively infinite.
        if self.max_epochs is None:
            return None

        # If the final length of the dataset is not known, return the current length of the underlying dataset.
        if not await self.dataset.final_length_is_known():
            return await self.dataset.current_len()

        # If the final length is known, return the max_epochs * async_len of the dataset.
        return self.max_epochs * await self.dataset.async_len()

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        # Wait until the underlying dataset is long enough (or its final length is known)
        # so we can wrap indices into its true length.
        max_index = max(indices)
        ds_len = await self.dataset.wait_until_len_at_least(max_index + 1)

        # Determine the epoch based on the largest index
        epoch = max_index // ds_len

        # If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs
        if self.max_epochs is not None and epoch >= self.max_epochs:
            raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}")

        # Wrap the indices within the bounds of the dataset length
        wrapped_indices = [idx % ds_len for idx in indices]

        # Delegate to the underlying dataset's get_batch
        return await self.dataset.get_batch(wrapped_indices)

    async def wait_until_len_at_least(self, length: int) -> int:
        """
        Returns the length of the dataset once it is at least `length` or if the dataset has a known (finished) length.

        If the dataset's actual length is less than `length`, it returns the minimum of async_len and the current length.
        """
        # Wait until the underlying dataset's length is at least `length`
        if not self.is_finite(): return length
        
        if await self.dataset.final_length_is_known(): 
            base_length = await self.dataset.async_len()
        else:
            base_length = await self.dataset.wait_until_len_at_least(length)

        if base_length < length:  # hit epoch boundary
            assert self.max_epochs is not None
            return self.max_epochs * base_length

        return base_length
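For reference, a minimal usage sketch of the wrapper above, assuming token_dataset is some existing finite AsyncDataset and that the finite-dataset check in __init__ is dropped (as discussed in the follow-up comment below):

import asyncio

async def demo(token_dataset):
    # Cycle the underlying dataset for at most two passes.
    epoched = EpochDataset(token_dataset, max_epochs=2)

    base_len = await token_dataset.async_len()
    # Indices past the end of the base dataset wrap back into it (second epoch).
    return await epoched.get_batch([0, base_len - 1, base_len])

# asyncio.run(demo(token_dataset))  # token_dataset: any finite AsyncDataset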

ahmeda14960 (Contributor, author):

FYI I removed the "cannot apply epoching to a finite dataset" check, since it seems like a bug: epoching is exactly what you'd want to do with a finite dataset.

@@ -27,6 +27,7 @@

from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore
from levanter.types import FilterSpec
# from levanter.trainer import StepInfo
Member:

rm

return # Can't calculate epochs without dataset size

# Calculate current epoch from steps without modifying StepInfo
current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size
Member:

we should probably just be tracking this explicitly in StepInfo, but this is fine right now
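A quick sanity check of the arithmetic in that snippet (illustrative numbers only):

# Hypothetical numbers: a dataset of 1,000,000 sequences, 1,024 sequences per training step.
step, batch_size, total_dataset_size = 2_000, 1_024, 1_000_000
current_epoch = (step * batch_size) // total_dataset_size
assert current_epoch == 2  # 2,048,000 sequences consumed, i.e. the third pass (epoch index 2)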

@@ -27,6 +27,7 @@

from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore
from levanter.types import FilterSpec
# from levanter.trainer import StepInfo
Member:

Suggested change: delete the line "# from levanter.trainer import StepInfo"

dlwh mentioned this pull request Oct 25, 2024