From febda824c92b8f4f9bc5f3b66641b06e517c8b7e Mon Sep 17 00:00:00 2001
From: Jayson Francis
Date: Sat, 2 Nov 2024 02:56:08 +0000
Subject: [PATCH] Add apply_torchchat_tp and optional seeded sampling

---
 test/generate/generation.py    | 57 ++++++++++++++++++++++++++++++----
 test/generate/test_generate.py | 14 ++++++---
 2 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/test/generate/generation.py b/test/generate/generation.py
index 96955147..571ec302 100644
--- a/test/generate/generation.py
+++ b/test/generate/generation.py
@@ -7,10 +7,51 @@
 from typing import Optional
 
 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
+from torch.distributed._tensor import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
+
+
+def apply_torchchat_tp(model: nn.Module, tp_mesh: DeviceMesh):
+    # As implemented in torchchat
+    # https://github.com/pytorch/torchchat/blob/main/torchchat/model.py#L679
+
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
+            "output": ColwiseParallel(output_layouts=Replicate()),
+        },
+    )
+
+    for layer_id, transformer_block in model.layers.items():
+        layer_plan = {
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(),
+            "feed_forward.w3": ColwiseParallel(),
+        }
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
 
 
-def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor:
-    q = torch.empty_like(probs).exponential_(1)
+def multinomial_sample_one(
+    probs: torch.Tensor, rng: Optional[torch.Generator] = None
+) -> torch.Tensor:
+    q = torch.empty_like(probs).exponential_(1, generator=rng)
     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long)
 
 
@@ -19,7 +60,6 @@ def logits_to_probs(
     temperature: float = 1.0,
     top_k: Optional[int] = None,
 ) -> torch.Tensor:
-
     logits = logits / max(temperature, 1e-5)
 
     if top_k is not None:
@@ -37,11 +77,11 @@ def generate_next_token(
     *,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
+    rng: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
-
     logits = model(x)  # (B, T, vocab_size)
     probs = logits_to_probs(logits[:, -1, :], temperature, top_k)
-    next_token = multinomial_sample_one(probs)
+    next_token = multinomial_sample_one(probs, rng=rng)
     return next_token
 
 
@@ -53,12 +93,16 @@ def generate(
     max_new_tokens: int,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
+    seed: Optional[int] = None,
 ) -> torch.Tensor:
-
     # ensure batch dimension (T,) --> (B, T)
    if input_ids.ndim == 1:
         input_ids = input_ids.unsqueeze(0)
 
+    rng = None
+    if seed is not None:
+        rng = torch.Generator(device=input_ids.device).manual_seed(seed)
+
     generated_tokens = input_ids.clone()
 
     for _ in range(max_new_tokens):
@@ -67,6 +111,7 @@ def generate(
             x=generated_tokens,
             temperature=temperature,
             top_k=top_k,
+            rng=rng,
         )
 
         generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py
index 631beea9..6d0c7486 100644
--- a/test/generate/test_generate.py
+++ b/test/generate/test_generate.py
@@ -30,7 +30,7 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from generate.generation import generate
+from generate.generation import apply_torchchat_tp, generate
 
 
 @record
@@ -57,9 +57,9 @@ def example_generate(
     utils.set_determinism(seed)
 
     if seed is None:
-        logger.info("Deterministic off")
+        logger.info("Deterministic sampling off")
     else:
-        logger.info(f"Deterministic on. Using seed: {seed}")
+        logger.info(f"Deterministic sampling on. Using seed: {seed}")
 
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -103,7 +103,12 @@ def example_generate(
     model = model_cls.from_model_args(model_config)
 
     if world_size > 1:
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config)
+
+        use_torchchat_tp = False
+        if use_torchchat_tp:
+            apply_torchchat_tp(model, world_mesh["tp"])  # torchchat-style TP plan; verified working
+        else:
+            models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config)
 
     # materialize model
     model.to_empty(device="cuda")
@@ -147,6 +152,7 @@ def example_generate(
         temperature=temperature,
         max_new_tokens=max_new_tokens,
         top_k=top_k,
+        seed=seed,
     )
     t1 = time.monotonic()
     elapsed_sec = t1 - t0
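
Note on the sampling change: multinomial_sample_one draws q ~ Exp(1)
elementwise and takes argmax(probs / q), the exponential-race form of the
Gumbel-max trick, which selects index i with probability probs[i]. Threading
a torch.Generator through that single draw is therefore enough to make the
whole decode loop reproducible.

Below is a minimal sketch of the seeded-generation contract this patch adds.
It assumes test/ is on sys.path (as test_generate.py arranges), and ToyLM is
a hypothetical stand-in that is not part of the patch; any module mapping
(B, T) token ids to (B, T, vocab_size) logits would work the same way.

    import torch
    import torch.nn as nn

    from generate.generation import generate

    class ToyLM(nn.Module):
        """Tiny stand-in LM: (B, T) token ids -> (B, T, vocab_size) logits."""

        def __init__(self, vocab_size: int = 32):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, 16)
            self.head = nn.Linear(16, vocab_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.emb(x))

    model = ToyLM()
    prompt = torch.tensor([1, 2, 3])

    # Same seed -> same torch.Generator stream -> identical sampled tokens.
    out_a = generate(model, prompt, max_new_tokens=8, temperature=0.8, top_k=5, seed=42)
    out_b = generate(model, prompt, max_new_tokens=8, temperature=0.8, top_k=5, seed=42)
    assert torch.equal(out_a, out_b)

With seed=None the generator is omitted and sampling falls back to the global
RNG, preserving the previous behavior.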