Commit 8696a72

Zero copy initialization of models onto training workers for LLMs (#3469)

Co-authored-by: Geoffrey Angus <[email protected]>
Co-authored-by: Travis Addair <[email protected]>
3 people authored Jul 22, 2023
1 parent 354627a commit 8696a72
Showing 11 changed files with 221 additions and 24 deletions.
4 changes: 2 additions & 2 deletions ludwig/api.py
@@ -638,12 +638,13 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
test_set=test_set,
save_path=model_dir,
)
(self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats

# Calibrates output feature probabilities on validation set if calibration is enabled.
# Must be done after training, and before final model parameters are saved.
if self.backend.is_coordinator():
calibrator = Calibrator(
trainer.model,
self.model,
self.backend,
batch_size=trainer.eval_batch_size,
)
@@ -684,7 +685,6 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
# List[TrainerMetric], with one entry per training checkpoint, according to steps_per_checkpoint.
# We reduce the dictionary of TrainerMetrics to a simple list of floats for interfacing with Ray
# Tune.
(self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats
train_stats = TrainingStats(
metric_utils.reduce_trainer_metrics_dict(train_trainset_stats),
metric_utils.reduce_trainer_metrics_dict(train_valiset_stats),
24 changes: 21 additions & 3 deletions ludwig/backend/ray.py
@@ -194,7 +194,10 @@ def train_fn(
if test_shard is not None:
test_shard = RayDatasetShard(test_shard, features, training_set_metadata)

model = ray.get(model_ref)
# Deserialize the model (minus weights) from Plasma
# Extract the weights from Plasma (without copying data)
# Load the weights back into the model in-place on the current device (CPU)
model = distributed.replace_model_from_serialization(ray.get(model_ref))
model = distributed.to_device(model)

trainer = remote_trainer_cls(
@@ -339,6 +342,7 @@ def __init__(self, trainer_kwargs: Dict[str, Any]) -> None:
trainer_kwargs = copy.copy(trainer_kwargs)
self.backend_config = trainer_kwargs.pop("backend", None)
self.strategy = trainer_kwargs.pop("strategy", get_default_strategy_name())
self.dist_strategy = get_dist_strategy(self.strategy)

if "max_retries" in trainer_kwargs:
logger.warning("`max_retries` is no longer supported as a trainer argument in Ray backend. Ignoring it.")
@@ -408,7 +412,7 @@ def run(

callbacks = callbacks or []

trainer_cls, kwargs = get_dist_strategy(self.strategy).get_trainer_cls(self.backend_config)
trainer_cls, kwargs = self.dist_strategy.get_trainer_cls(self.backend_config)
train_loop_config = {**config, "distributed_strategy": self.strategy}
trainer = trainer_cls(
train_loop_per_worker=train_loop_per_worker,
@@ -475,11 +479,18 @@ def train(
stream_window_size["test"] = test_set.window_size_bytes

with create_runner(**self.trainer_kwargs) as runner:
# Extract weights as numpy tensors and place them in the Ray object store.
# If we store the weights of a model as NumPy arrays on Plasma, we can access those
# weights directly out of Plasma’s shared memory segments, without making any copies.
# This enables zero copy model loading on each training worker using shared
# memory from the Ray object store for model initialization.
dist_strategy = runner.dist_strategy
model_ref = ray.put(dist_strategy.extract_model_for_serialization(self.model))
trainer_results = runner.run(
lambda config: train_fn(**config),
config={
"executable_kwargs": executable_kwargs,
"model_ref": ray.put(self.model),
"model_ref": model_ref,
"remote_trainer_cls": self.remote_trainer_cls,
**kwargs,
},
@@ -489,6 +500,13 @@
stream_window_size=stream_window_size,
)

# re-register the weights of the model object in the main process
self.model = dist_strategy.replace_model_from_serialization(ray.get(model_ref))

# ensure module is initialized exactly as it is in the trainer process
# so that the state dict can be loaded back into the model correctly.
self.model.prepare_for_training()

# Set validation field and metric used by trainer
self._validation_field = trainer_results.metrics["validation_field"]
self._validation_metric = trainer_results.metrics["validation_metric"]
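The comment blocks added to train_fn and train() above describe the whole round trip: the coordinator strips the model's weights out as NumPy arrays, puts them into the Ray object store (Plasma) once, and each training worker maps them straight out of shared memory before rebuilding the torch tensors in place on CPU. A minimal, self-contained sketch of that handoff, reusing the extract_tensors/replace_tensors helpers introduced in ludwig/utils/model_utils.py later in this diff (the toy nn.Linear model and the remote task are illustrative, not part of this commit):

# Illustrative sketch only -- not part of this commit.
import ray
import torch

from ludwig.utils.model_utils import extract_tensors, replace_tensors

ray.init()

# Coordinator: strip the weights out of the module as NumPy arrays and place the
# (skeleton module, weights) tuple in the Ray object store (Plasma).
model = torch.nn.Linear(8, 4)  # stand-in for the LLM being trained
model_ref = ray.put(extract_tensors(model))


@ray.remote
def train_worker(config):
    # Worker: ray.get maps the NumPy arrays out of Plasma's shared memory without
    # copying them; replace_tensors then rebuilds the torch tensors in place on CPU.
    skeleton, weights = ray.get(config["model_ref"])
    replace_tensors(skeleton, weights, torch.device("cpu"))
    return sum(p.numel() for p in skeleton.parameters())


print(ray.get(train_worker.remote({"model_ref": model_ref})))

Arrays fetched from Plasma are read-only, so PyTorch may warn when a sketch like this wraps them as parameters.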
11 changes: 10 additions & 1 deletion ludwig/distributed/base.py
@@ -1,6 +1,6 @@
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union

import torch
from torch import nn
@@ -182,6 +182,15 @@ def create_checkpoint_handle(

return MultiNodeCheckpoint(self, model, optimizer, scheduler)

@classmethod
def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]:
return model

@classmethod
def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module:
assert isinstance(state, nn.Module)
return state


class LocalStrategy(DistributedStrategy):
def prepare(
14 changes: 13 additions & 1 deletion ludwig/distributed/deepspeed.py
@@ -1,7 +1,7 @@
import logging
import os
import warnings
from typing import Any, Dict, Mapping, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

import deepspeed
import deepspeed.comm
@@ -14,6 +14,7 @@
from ludwig.distributed.ddp import DDPStrategy
from ludwig.modules.optimization_modules import get_optimizer_class_and_kwargs
from ludwig.utils.checkpoint_utils import Checkpoint
from ludwig.utils.model_utils import extract_tensors, replace_tensors

if TYPE_CHECKING:
from ludwig.modules.lr_scheduler import LRScheduler
@@ -219,3 +220,14 @@ def get_state_for_inference(self, save_path: str, device: Optional[torch.device]
save_path, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True
)
return self.model.module.cpu().state_dict()

@classmethod
def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]:
return extract_tensors(model)

@classmethod
def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module:
assert isinstance(state, tuple)
model, model_weights = state
replace_tensors(model, model_weights, torch.device("cpu"))
return model
24 changes: 14 additions & 10 deletions ludwig/models/llm.py
@@ -89,17 +89,20 @@ def __init__(
self._random_seed = random_seed

self.model_name = self.config_obj.base_model
self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)

logger.info("Loading large language model...")
self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model)
self.curr_device = torch.device("cpu") # model initially loaded onto cpu

# Model initially loaded onto cpu
self.curr_device = torch.device("cpu")
logger.info("Done.")

# Determines the maximum length of the context (input + output tokens)
if hasattr(self.model.config, "max_sequence_length"):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model.config, "max_position_embeddings"):
self.context_len = self.model.config.max_position_embeddings
if hasattr(self.model_config, "max_sequence_length"):
self.context_len = self.model_config.max_sequence_length
elif hasattr(self.model_config, "max_position_embeddings"):
self.context_len = self.model_config.max_position_embeddings
else:
self.context_len = 2048

@@ -121,7 +124,7 @@

# Initialize tokenizer
use_fast = True
if isinstance(AutoConfig.from_pretrained(self.config_obj.base_model), LlamaConfig):
if isinstance(self.model_config, LlamaConfig):
# HACK: Llama fast tokenizer takes about 2-4 minutes to load, so we disable it for now.
use_fast = False
self.tokenizer = AutoTokenizer.from_pretrained(self.config_obj.base_model, use_fast=use_fast)
@@ -152,16 +155,13 @@ def __init__(
# because the model has additional "head" layers that are used to predict the next
# token in the sequence. These head layers can add additional dimensions to the
# logits tensor, beyond the vocab_size dimension.
input_size=self.input_shape[-1] if self.output_feature_type == TEXT else self.model.config.vocab_size,
input_size=self.input_shape[-1] if self.output_feature_type == TEXT else self.model_config.vocab_size,
)
)

# Extract the decoder object for the forward pass
self._output_feature_decoder = ModuleWrapper(self.output_features.items()[0][1])

# Initialize the PEFT adapter if one is provided
self.initialize_adapter()

clear_data_cache()

def create_feature_dict(self) -> LudwigFeatureDict:
@@ -193,6 +193,10 @@ def initialize_adapter(self):
self.model.print_trainable_parameters()
logger.info("==================================================")

def prepare_for_training(self):
# TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this.
self.initialize_adapter()

def to_device(self, device):
device = torch.device(device)

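The prepare_for_training() hook added here moves PEFT adapter initialization out of the LLM constructor, so the adapter layers are created inside each training process rather than travelling through Plasma with the base weights. The trainer.py change below is what invokes it; the intended ordering, sketched with model and distributed standing in for the Trainer's attributes:

# Illustrative ordering only; `model` and `distributed` stand in for the Trainer's attributes below.
model.prepare_for_training()          # create the PEFT adapter inside the trainer process
model = distributed.to_device(model)  # move to the training device after the adapter exists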
1 change: 1 addition & 0 deletions ludwig/trainers/trainer.py
@@ -175,6 +175,7 @@ def __init__(
self.base_learning_rate = base_learning_rate

self.model = model
self.model.prepare_for_training()
self.model = self.distributed.to_device(self.model)
self.model.metrics_to_device(self.device)

76 changes: 76 additions & 0 deletions ludwig/utils/model_utils.py
@@ -0,0 +1,76 @@
from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch

NUMPY_TO_TORCH_DTYPE = {
bool: torch.bool,
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}


def extract_tensors(model: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
"""Remove the tensors from a PyTorch model, convert them to NumPy arrays, and return the stripped model and
tensors.
Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with-
ray-8be751a6944c # noqa
"""

tensors = []
for _, module in model.named_modules():
# Store the tensors as numpy arrays in Python dictionaries
# Delete the same tensors since we no longer need them and we want to reduce memory pressure.
# This ensures that throughout this process, we keep memory nearly linear w.r.t model parameters.
params = OrderedDict()
buffers = OrderedDict()
for name, param in module.named_parameters(recurse=False):
params[name] = torch.clone(param).detach().numpy()
del param
for name, buf in module.named_buffers(recurse=False):
buffers[name] = torch.clone(buf).detach().numpy()
del buf
tensors.append({"params": params, "buffers": buffers})

# Strip all tensors and buffers out of the original model.
for _, module in model.named_modules():
for name in [name for name, _ in module.named_parameters(recurse=False)] + [
name for name, _ in module.named_buffers(recurse=False)
]:
setattr(module, name, None)

return model, tensors


def replace_tensors(m: torch.nn.Module, tensors: List[Dict], device: torch.device):
"""Restore the tensors that extract_tensors() stripped out of a PyTorch model. This operation is performed in
place.
Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with-
ray-8be751a6944c # noqa
"""
modules = [module for _, module in m.named_modules()]
for module, tensor_dict in zip(modules, tensors):
# There are separate APIs to set parameters and buffers.
for name, array in tensor_dict["params"].items():
module.register_parameter(
name,
torch.nn.Parameter(torch.as_tensor(array, device=device, dtype=NUMPY_TO_TORCH_DTYPE.get(array.dtype))),
)

for name, array in tensor_dict["buffers"].items():
module.register_buffer(
name,
torch.as_tensor(array, device=device, dtype=NUMPY_TO_TORCH_DTYPE.get(array.dtype)),
)
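A quick way to sanity-check the two helpers above locally, without Ray, is to strip a small module, restore it, and compare state dicts. This snippet is illustrative and not part of the commit:

# Illustrative round-trip check -- not part of this commit.
import torch

from ludwig.utils.model_utils import extract_tensors, replace_tensors

original = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
reference = {k: v.clone() for k, v in original.state_dict().items()}

skeleton, weights = extract_tensors(original)            # parameters and buffers are now None
replace_tensors(skeleton, weights, torch.device("cpu"))  # restored in place as CPU tensors

assert all(torch.equal(reference[k], v) for k, v in skeleton.state_dict().items())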
4 changes: 4 additions & 0 deletions ludwig/utils/torch_utils.py
@@ -183,6 +183,10 @@ def __init__(self):
def device(self):
return self.device_tensor.device

def prepare_for_training(self):
"""This is called from within the Trainer object to do any final instantiation before model training."""
pass

def losses(self):
collected_losses = []
for loss in self._losses.values():
10 changes: 7 additions & 3 deletions requirements.txt
@@ -10,9 +10,13 @@ torch>=1.13.0
torchaudio
torchtext
torchvision
transformers>=4.28.1
transformers>=4.31.0
tokenizers>=0.13.3
spacy>=2.3
PyYAML>=3.12

# https://github.com/yaml/pyyaml/issues/601
PyYAML>=3.12,<6.0.1

absl-py
kaggle
requests
@@ -24,7 +28,7 @@ marshmallow
marshmallow-jsonschema
marshmallow-dataclass==8.5.4
tensorboard
torchmetrics<=0.11.4
torchmetrics>=0.11.0,<=0.11.4
torchinfo
filelock
psutil==5.9.4