[WIP] Implements Roberta Model #679

Draft · wants to merge 32 commits into main

Commits (32)
d6284ca
[WIP] Implements Roberta Model
Jul 30, 2024
8f7402e
Implements dynamic masking objective
prady-saligram Jul 30, 2024
670b053
Implements dynamic masked dataset
prady-saligram Jul 30, 2024
42f5404
Reintroduced accidentally deleted CausalLMDataset class
prady-saligram Jul 30, 2024
9ad06af
Everything works except stuck on the final method,
Aug 1, 2024
53fd8d2
[WIP] Re-implements MLM training objective
prady-saligram Aug 5, 2024
dcd45b2
Adds error handling and reverts LmExample class to original
prady-saligram Aug 6, 2024
6f21e0d
Testing Modifications
Aug 13, 2024
730d847
Merge branch 'stanford-crfm:main' into roberta-model
prady-saligram Aug 26, 2024
027b176
Sets RobertaConfig as model architecture and creates default config file
prady-saligram Aug 26, 2024
399e08c
Adds compute_loss to roberta and changes positional ids to begin from 0
prady-saligram Sep 1, 2024
cd4118c
Investigating precision loss over time within the model using output…
Sep 4, 2024
96522f1
Merge branch 'roberta-model' of https://github.com/JulienDarve/levant…
Sep 4, 2024
8a732e5
Model can now successfully import weights from huggingface + made att…
Sep 10, 2024
5f3d8a2
Merge branch 'roberta-training' into roberta-model-copy-2
Sep 10, 2024
6c105f5
trial
Sep 12, 2024
ab85079
update 1
Sep 12, 2024
5b97400
update 2
Sep 12, 2024
bd7d411
update 3
Sep 12, 2024
b5d8e14
update
Sep 12, 2024
8717c3f
update
Sep 12, 2024
10c130c
update
Sep 12, 2024
834d88d
update
Sep 12, 2024
47fe23b
update
Sep 12, 2024
fb5c55c
update
Sep 12, 2024
8594e79
update
Sep 12, 2024
3ae80d7
update
Sep 12, 2024
de93fc9
update
Sep 12, 2024
896af7d
update
Sep 12, 2024
0be9a83
update
Sep 12, 2024
0c94a47
update
Sep 12, 2024
7ae681d
Training works!
JulienDarve Sep 13, 2024
39 changes: 39 additions & 0 deletions config/roberta-tiny.yaml
@@ -0,0 +1,39 @@
data:
  id: dlwh/wikitext_103_detokenized
  # train_urls:
  #   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  # validation_urls:
  #   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "cache/roberta-tiny"
  tokenizer: "roberta-base"

model:
  type: roberta
  vocab_size: 50265
  hidden_size: 32
  intermediate_size: 64
  num_hidden_layers: 4
  num_attention_heads: 2
  max_position_embeddings: 512
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  gradient_checkpointing: true

trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: ["openwebtext", "roberta", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 32
  num_train_steps: 20000

optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
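For a rough sense of scale, the tiny model above comes out under two million parameters, almost all of them in the embedding table. A back-of-the-envelope sketch (ignoring biases, layer norms, and the MLM head; these are simplifying assumptions, not numbers taken from this PR):

vocab_size, hidden, intermediate = 50265, 32, 64
num_layers, max_pos = 4, 512

# embeddings dominate at this width: token table + learned positions
embeddings = vocab_size * hidden + max_pos * hidden
# per transformer layer: attention projections (Q, K, V, O) + the two FFN matrices
per_layer = 4 * hidden * hidden + 2 * hidden * intermediate
total = embeddings + num_layers * per_layer

print(f"~{total / 1e6:.2f}M parameters")  # ~1.66M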
34 changes: 34 additions & 0 deletions config/roberta.yaml
@@ -0,0 +1,34 @@
data:
  id: dlwh/wikitext_103_detokenized
  tokenizer: "roberta-base"

model:
  type: roberta
  vocab_size: 50265
  hidden_size: 768
  intermediate_size: 3072
  num_hidden_layers: 12
  num_attention_heads: 12
  max_position_embeddings: 512
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  gradient_checkpointing: true

trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: ["openwebtext", "roberta", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 32
  num_train_steps: 20000

optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
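The full config mirrors the published roberta-base architecture. A quick sanity check a reviewer might run (this assumes the transformers package and is not part of the PR) is to compare it against the Hugging Face RobertaConfig; note that HF reports max_position_embeddings as 514 because RoBERTa offsets position ids past the padding index, which relates to the positional-id change in commit 399e08c:

from transformers import RobertaConfig

hf = RobertaConfig.from_pretrained("roberta-base")
print(hf.vocab_size, hf.hidden_size, hf.intermediate_size)          # 50265 768 3072
print(hf.num_hidden_layers, hf.num_attention_heads, hf.hidden_act)  # 12 12 gelu
print(hf.max_position_embeddings)                                   # 514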
91 changes: 81 additions & 10 deletions src/levanter/data/text.py
@@ -14,6 +14,7 @@
import equinox as eqx
import fsspec
import jax
import jax.numpy as jnp
import numpy as np
import pyarrow as pa
import regex
@@ -25,13 +26,11 @@

from levanter.data.mixture import MixtureDataset, StopStrategy

# intercept the logging nonsense here
from levanter.logging import silence_transformer_nag # noqa
from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample
from levanter.models.lm_model import MaskedLmExample, LmExample
from levanter.utils.hf_utils import num_cpus_used_by_tokenizer


silence_transformer_nag() # noqa
from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa

@@ -47,7 +46,6 @@
from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa
from levanter.utils.jax_utils import use_cpu_device # noqa


logger = logging.getLogger("levanter.data.text")

# TASKS:
@@ -58,6 +56,83 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

class MaskedLmDataset(ShardableDataset[MaskedLmExample]):
    """Wraps a dataset of fixed-length token sequences and yields MaskedLmExamples with
    RoBERTa-style dynamic masking (80% mask token, `noise_prob` random tokens, rest unchanged)."""

    def __init__(
        self,
        dataset: ShardableDataset[np.ndarray],
        QPos: Axis,
        KPos: Axis,
        mask_token_id: int,
        mask_prob: float = 0.15,
        noise_prob: float = 0.1,
        key: Optional[PRNGKeyArray] = None,
        # ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX,
    ):
        self.dataset = dataset
        self.QPos = QPos
        self.KPos = KPos
        self.mask_prob = mask_prob
        self.noise_prob = noise_prob
        self.key = key
        self.mask_token_id = mask_token_id

        if self.mask_prob > 0.0 and self.key is None:
            raise ValueError("must provide key if mask_prob > 0.0")

    def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset":
        return MaskedLmDataset(
            self.dataset.shard(shard_id, num_shards),
            self.QPos,
            self.KPos,
            self.mask_token_id,
            self.mask_prob,
            self.noise_prob,
            self.key,
        )

    def __iter__(self) -> Iterator[MaskedLmExample]:
        key = self.key
        sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])

        with use_cpu_device():

            @functools.partial(eqx.filter_jit, out_shardings=sharding)
            def _create_mlm_example(tokens, key):
                tokens_array = tokens.array
                targets = tokens_array.copy()

                if self.mask_prob > 0:
                    # use independent keys for the three random draws so the 80/10/10 split
                    # is not correlated with which positions were selected for masking
                    mask_key, replace_key, random_key = jax.random.split(key, 3)
                    mask_shape = tokens_array.shape
                    mask = jax.random.bernoulli(mask_key, self.mask_prob, mask_shape)

                    rand = jax.random.uniform(replace_key, mask_shape)
                    mask_token = jnp.where(rand < 0.8, self.mask_token_id, tokens_array)
                    random_tokens = jax.random.randint(random_key, mask_shape, 0, tokens_array.max() + 1)
                    mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token)
                    masked_tokens = jnp.where(mask, mask_token, tokens_array)

                    # Set targets to the original tokens where mask is True, otherwise set to mask_token_id
                    targets = jnp.where(mask, tokens_array, self.mask_token_id)

                    masked_tokens_named = hax.named(masked_tokens, self.QPos)
                    targets_named = hax.named(targets, self.QPos)

                    attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
                    attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos))

                    example = MaskedLmExample.masked_lm(
                        tokens=masked_tokens_named, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask
                    )
                else:
                    targets_named = hax.named(targets, self.QPos)
                    attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
                    attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos))

                    example = MaskedLmExample.masked_lm(
                        tokens=tokens, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask
                    )

                return example

            for tokens in self.dataset:
                # split off a fresh key per example so the mask pattern differs across sequences (dynamic masking)
                example_key = None
                if key is not None:
                    key, example_key = jax.random.split(key)
                tokens_array = jnp.array(tokens)
                tokens_named = hax.named(tokens_array, self.QPos)
                example = _create_mlm_example(tokens_named, example_key)
                yield example
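A minimal usage sketch for MaskedLmDataset (hypothetical wiring, not part of this diff: token_dataset stands in for any ShardableDataset[np.ndarray] of fixed-length token ids, such as a TokenSeqDataset, and the axis setup follows the conventions used elsewhere in levanter):

import jax
import haliax as hax
from transformers import AutoTokenizer

from levanter.data.text import MaskedLmDataset

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
Pos = hax.Axis("position", 512)
KPos = Pos.alias("key_position")

mlm_dataset = MaskedLmDataset(
    token_dataset,  # placeholder: any ShardableDataset[np.ndarray] of token id sequences
    Pos,
    KPos,
    mask_token_id=tokenizer.mask_token_id,
    mask_prob=0.15,  # RoBERTa's masking rate
    noise_prob=0.1,  # fraction of selected positions replaced with random tokens
    key=jax.random.PRNGKey(0),
)

for example in mlm_dataset:  # yields MaskedLmExample instances
    ...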



class CausalLmDataset(ShardableDataset[LmExample]):
def __init__(
@@ -89,18 +164,13 @@ def __iter__(self) -> Iterator[LmExample]:
        sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])

        with use_cpu_device():

            @functools.partial(eqx.filter_jit, out_shardings=sharding)
            def _create_lm_example(tokens, key):
                tokens = hax.named(tokens, self.QPos)

                example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)

                if self.fcm_prob > 0:
                    # masks for attention
                    # We support forgetful causal masking (FCM) which is a technique that improves training speed by
                    # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention
                    # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432
                    assert self.key is not None
                    this_key, key = jax.random.split(key)
                    fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key)
@@ -114,6 +184,7 @@ def _create_lm_example(tokens, key):
                yield example
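For context on the FCM comment above, a conceptual sketch of forgetful causal masking in plain jax.numpy (illustrative only; the real code path goes through hax.nn.attention.forgetful_causal_mask):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
seq_len, fcm_prob = 8, 0.15

causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))   # standard causal mask
keep = jax.random.bernoulli(key, 1.0 - fcm_prob, (seq_len,))  # drop each key position with prob fcm_prob
keep = keep.at[0].set(True)                                   # keep the first position visible (a choice made here, not necessarily levanter's)
fcm_mask = causal & keep[None, :]                             # broadcast the kept columns over query positions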



class TokenSeqDataset(ShardableDataset[np.ndarray]):
"""
A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache.
@@ -826,4 +897,4 @@ def build_caches(

@property
def sources(self) -> dict[str, LMDatasetSourceConfig]:
return self.configs
return self.configs