Commit: Adding apply_torchchat_tp
jaysonfrancis committed Nov 2, 2024
1 parent 76ce752 commit febda82
Showing 2 changed files with 61 additions and 10 deletions.
57 changes: 51 additions & 6 deletions test/generate/generation.py
@@ -7,10 +7,51 @@
 from typing import Optional

 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
+from torch.distributed._tensor import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)


+def apply_torchchat_tp(model: nn.Module, tp_mesh: DeviceMesh):
+    # As implemented in torchchat
+    # https://github.com/pytorch/torchchat/blob/main/torchchat/model.py#L679
+
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
+            "output": ColwiseParallel(output_layouts=Replicate()),
+        },
+    )
+
+    for layer_id, transformer_block in model.layers.items():
+        layer_plan = {
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(),
+            "feed_forward.w3": ColwiseParallel(),
+        }
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+
-def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor:
-    q = torch.empty_like(probs).exponential_(1)
+def multinomial_sample_one(
+    probs: torch.Tensor, rng: Optional[torch.Generator] = None
+) -> torch.Tensor:
+    q = torch.empty_like(probs).exponential_(1, generator=rng)
     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long)
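
A minimal usage sketch for the new function (not part of this commit; `model` is assumed to be a torchtitan Transformer on the meta device whose submodule names match the plan above, and torchrun is assumed to have initialized the process group):

    import os

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    # Build a 1-D tensor-parallel mesh over all ranks.
    world_size = int(os.environ["WORLD_SIZE"])
    tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

    apply_torchchat_tp(model, tp_mesh)  # parameters become sharded DTensors
    model.to_empty(device="cuda")       # materialize weights after sharding

The plan pairs column-wise input projections with row-wise output projections, so each attention and feed-forward sublayer needs only a single all-reduce in the forward pass.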


@@ -19,7 +60,6 @@ def logits_to_probs(
     temperature: float = 1.0,
     top_k: Optional[int] = None,
 ) -> torch.Tensor:
-
     logits = logits / max(temperature, 1e-5)

     if top_k is not None:
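
The body of the `top_k` branch is collapsed in this hunk. For reference, a generic top-k filter of the kind used at this point — a sketch, not the repository's exact lines:

    import torch

    def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
        # Keep the k largest logits and push the rest to -inf,
        # so softmax assigns those tokens zero probability.
        v, _ = torch.topk(logits, min(k, logits.size(-1)))
        return logits.masked_fill(logits < v[..., [-1]], float("-inf"))
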
@@ -37,11 +77,11 @@ def generate_next_token(
     *,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
+    rng: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
-
     logits = model(x) # (B, T, vocab_size)
     probs = logits_to_probs(logits[:, -1, :], temperature, top_k)
-    next_token = multinomial_sample_one(probs)
+    next_token = multinomial_sample_one(probs, rng=rng)
     return next_token
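
`multinomial_sample_one` is the Gumbel-max trick in exponential form: with q drawn i.i.d. from Exponential(1), argmax(probs / q) selects index i with probability proportional to probs[i], so threading one seeded `torch.Generator` through `rng` makes every draw reproducible. A quick sanity check with illustrative values:

    import torch

    probs = torch.tensor([[0.1, 0.2, 0.7]])

    # Identically seeded generators yield identical exponential noise,
    # hence identical sampled tokens.
    g1 = torch.Generator().manual_seed(42)
    g2 = torch.Generator().manual_seed(42)
    assert torch.equal(
        multinomial_sample_one(probs, rng=g1),
        multinomial_sample_one(probs, rng=g2),
    )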


@@ -53,12 +93,16 @@ def generate(
     max_new_tokens: int,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
+    seed: Optional[int] = None,
 ) -> torch.Tensor:
-
     # ensure batch dimension (T,) --> (B, T)
     if input_ids.ndim == 1:
         input_ids = input_ids.unsqueeze(0)

+    rng = None
+    if seed is not None:
+        rng = torch.Generator(input_ids.device).manual_seed(seed)
+
     generated_tokens = input_ids.clone()

     for _ in range(max_new_tokens):
@@ -67,6 +111,7 @@ def generate(
             x=generated_tokens,
             temperature=temperature,
             top_k=top_k,
+            rng=rng,
         )

         generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
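
End to end, the new `seed` parameter makes generation reproducible per call. A usage sketch with hypothetical token ids and sampling settings (the leading model and input_ids parameters of `generate` are collapsed in the hunk above and assumed here):

    import torch

    prompt = torch.tensor([1, 2, 3], device="cuda")  # hypothetical token ids

    out_a = generate(model, prompt, max_new_tokens=16, temperature=0.8, top_k=50, seed=42)
    out_b = generate(model, prompt, max_new_tokens=16, temperature=0.8, top_k=50, seed=42)
    assert torch.equal(out_a, out_b)  # same seed, same continuation
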
14 changes: 10 additions & 4 deletions test/generate/test_generate.py
@@ -30,7 +30,7 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from generate.generation import generate
+from generate.generation import apply_torchchat_tp, generate


 @record
@@ -57,9 +57,9 @@ def example_generate(
     utils.set_determinism(seed)

     if seed is None:
-        logger.info("Deterministic off")
+        logger.info("Deterministic sampling off")
     else:
-        logger.info(f"Deterministic on. Using seed: {seed}")
+        logger.info(f"Deterministic sampling on. Using seed: {seed}")

     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -103,7 +103,12 @@ def example_generate(
     model = model_cls.from_model_args(model_config)

     if world_size > 1:
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config)
+
+        use_torchchat_tp = False
+        if use_torchchat_tp:
+            apply_torchchat_tp(model, world_mesh["tp"])  # Working
+        else:
+            models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config)

     # materialize model
     model.to_empty(device="cuda")
@@ -147,6 +152,7 @@ def example_generate(
         temperature=temperature,
         max_new_tokens=max_new_tokens,
         top_k=top_k,
+        seed=seed,
     )
     t1 = time.monotonic()
     elapsed_sec = t1 - t0