From 8696a72730a231808e43c3ed4f5e18359e3bed3b Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Sat, 22 Jul 2023 01:57:41 -0700 Subject: [PATCH] Zero copy initialization of models onto training workers for LLMs (#3469) Co-authored-by: Geoffrey Angus Co-authored-by: Travis Addair --- ludwig/api.py | 4 +- ludwig/backend/ray.py | 24 +++++++- ludwig/distributed/base.py | 11 +++- ludwig/distributed/deepspeed.py | 14 ++++- ludwig/models/llm.py | 24 ++++---- ludwig/trainers/trainer.py | 1 + ludwig/utils/model_utils.py | 76 ++++++++++++++++++++++++++ ludwig/utils/torch_utils.py | 4 ++ requirements.txt | 10 +++- tests/integration_tests/test_llm.py | 16 ++++-- tests/ludwig/utils/test_model_utils.py | 61 +++++++++++++++++++++ 11 files changed, 221 insertions(+), 24 deletions(-) create mode 100644 ludwig/utils/model_utils.py create mode 100644 tests/ludwig/utils/test_model_utils.py diff --git a/ludwig/api.py b/ludwig/api.py index 700dd854b55..6e0230c347e 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -638,12 +638,13 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): test_set=test_set, save_path=model_dir, ) + (self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats # Calibrates output feature probabilities on validation set if calibration is enabled. # Must be done after training, and before final model parameters are saved. if self.backend.is_coordinator(): calibrator = Calibrator( - trainer.model, + self.model, self.backend, batch_size=trainer.eval_batch_size, ) @@ -684,7 +685,6 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): # List[TrainerMetric], with one entry per training checkpoint, according to steps_per_checkpoint. # We reduce the dictionary of TrainerMetrics to a simple list of floats for interfacing with Ray # Tune. - (self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats train_stats = TrainingStats( metric_utils.reduce_trainer_metrics_dict(train_trainset_stats), metric_utils.reduce_trainer_metrics_dict(train_valiset_stats), diff --git a/ludwig/backend/ray.py b/ludwig/backend/ray.py index 53d95f6b685..910e4efa02c 100644 --- a/ludwig/backend/ray.py +++ b/ludwig/backend/ray.py @@ -194,7 +194,10 @@ def train_fn( if test_shard is not None: test_shard = RayDatasetShard(test_shard, features, training_set_metadata) - model = ray.get(model_ref) + # Deserialize the model (minus weights) from Plasma + # Extract the weights from Plasma (without copying data) + # Load the weights back into the model in-place on the current device (CPU) + model = distributed.replace_model_from_serialization(ray.get(model_ref)) model = distributed.to_device(model) trainer = remote_trainer_cls( @@ -339,6 +342,7 @@ def __init__(self, trainer_kwargs: Dict[str, Any]) -> None: trainer_kwargs = copy.copy(trainer_kwargs) self.backend_config = trainer_kwargs.pop("backend", None) self.strategy = trainer_kwargs.pop("strategy", get_default_strategy_name()) + self.dist_strategy = get_dist_strategy(self.strategy) if "max_retries" in trainer_kwargs: logger.warning("`max_retries` is no longer supported as a trainer argument in Ray backend. 
Ignoring it.") @@ -408,7 +412,7 @@ def run( callbacks = callbacks or [] - trainer_cls, kwargs = get_dist_strategy(self.strategy).get_trainer_cls(self.backend_config) + trainer_cls, kwargs = self.dist_strategy.get_trainer_cls(self.backend_config) train_loop_config = {**config, "distributed_strategy": self.strategy} trainer = trainer_cls( train_loop_per_worker=train_loop_per_worker, @@ -475,11 +479,18 @@ def train( stream_window_size["test"] = test_set.window_size_bytes with create_runner(**self.trainer_kwargs) as runner: + # Extract weights as numpy tensors and place them in the Ray object store. + # If we store the weights of a model as NumPy arrays on Plasma, we can access those + # weights directly out of Plasma’s shared memory segments, without making any copies. + # This enables zero copy model loading on each training worker using shared + # memory from the Ray object store for model initialization. + dist_strategy = runner.dist_strategy + model_ref = ray.put(dist_strategy.extract_model_for_serialization(self.model)) trainer_results = runner.run( lambda config: train_fn(**config), config={ "executable_kwargs": executable_kwargs, - "model_ref": ray.put(self.model), + "model_ref": model_ref, "remote_trainer_cls": self.remote_trainer_cls, **kwargs, }, @@ -489,6 +500,13 @@ def train( stream_window_size=stream_window_size, ) + # re-register the weights of the model object in the main process + self.model = dist_strategy.replace_model_from_serialization(ray.get(model_ref)) + + # ensure module is initialized exactly as it is in the trainer process + # so that the state dict can be loaded back into the model correctly. + self.model.prepare_for_training() + # Set validation field and metric used by trainer self._validation_field = trainer_results.metrics["validation_field"] self._validation_metric = trainer_results.metrics["validation_metric"] diff --git a/ludwig/distributed/base.py b/ludwig/distributed/base.py index b08648169bd..cbe1de0fd5b 100644 --- a/ludwig/distributed/base.py +++ b/ludwig/distributed/base.py @@ -1,6 +1,6 @@ import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Tuple, Type, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union import torch from torch import nn @@ -182,6 +182,15 @@ def create_checkpoint_handle( return MultiNodeCheckpoint(self, model, optimizer, scheduler) + @classmethod + def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]: + return model + + @classmethod + def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module: + assert isinstance(state, nn.Module) + return state + class LocalStrategy(DistributedStrategy): def prepare( diff --git a/ludwig/distributed/deepspeed.py b/ludwig/distributed/deepspeed.py index 3a80a572590..f92577f1753 100644 --- a/ludwig/distributed/deepspeed.py +++ b/ludwig/distributed/deepspeed.py @@ -1,7 +1,7 @@ import logging import os import warnings -from typing import Any, Dict, Mapping, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union import deepspeed import deepspeed.comm @@ -14,6 +14,7 @@ from ludwig.distributed.ddp import DDPStrategy from ludwig.modules.optimization_modules import get_optimizer_class_and_kwargs from ludwig.utils.checkpoint_utils import Checkpoint +from ludwig.utils.model_utils import extract_tensors, replace_tensors if TYPE_CHECKING: from 
ludwig.modules.lr_scheduler import LRScheduler @@ -219,3 +220,14 @@ def get_state_for_inference(self, save_path: str, device: Optional[torch.device] save_path, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True ) return self.model.module.cpu().state_dict() + + @classmethod + def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]: + return extract_tensors(model) + + @classmethod + def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module: + assert isinstance(state, tuple) + model, model_weights = state + replace_tensors(model, model_weights, torch.device("cpu")) + return model diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index c3d0e8fe2e4..06c25c19693 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -89,17 +89,20 @@ def __init__( self._random_seed = random_seed self.model_name = self.config_obj.base_model + self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model) logger.info("Loading large language model...") self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model) - self.curr_device = torch.device("cpu") # model initially loaded onto cpu + + # Model initially loaded onto cpu + self.curr_device = torch.device("cpu") logger.info("Done.") # Determines the maximum length of the context (input + output tokens) - if hasattr(self.model.config, "max_sequence_length"): - self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model.config, "max_position_embeddings"): - self.context_len = self.model.config.max_position_embeddings + if hasattr(self.model_config, "max_sequence_length"): + self.context_len = self.model_config.max_sequence_length + elif hasattr(self.model_config, "max_position_embeddings"): + self.context_len = self.model_config.max_position_embeddings else: self.context_len = 2048 @@ -121,7 +124,7 @@ def __init__( # Initialize tokenizer use_fast = True - if isinstance(AutoConfig.from_pretrained(self.config_obj.base_model), LlamaConfig): + if isinstance(self.model_config, LlamaConfig): # HACK: Llama fast tokenizer takes about 2-4 minutes to load, so we disable it for now. use_fast = False self.tokenizer = AutoTokenizer.from_pretrained(self.config_obj.base_model, use_fast=use_fast) @@ -152,16 +155,13 @@ def __init__( # because the model has additional "head" layers that are used to predict the next # token in the sequence. These head layers can add additional dimensions to the # logits tensor, beyond the vocab_size dimension. - input_size=self.input_shape[-1] if self.output_feature_type == TEXT else self.model.config.vocab_size, + input_size=self.input_shape[-1] if self.output_feature_type == TEXT else self.model_config.vocab_size, ) ) # Extract the decoder object for the forward pass self._output_feature_decoder = ModuleWrapper(self.output_features.items()[0][1]) - # Initialize the PEFT adapter is one is provided - self.initialize_adapter() - clear_data_cache() def create_feature_dict(self) -> LudwigFeatureDict: @@ -193,6 +193,10 @@ def initialize_adapter(self): self.model.print_trainable_parameters() logger.info("==================================================") + def prepare_for_training(self): + # TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this. 
+ self.initialize_adapter() + def to_device(self, device): device = torch.device(device) diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index c117e8ae23a..5c806e83779 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -175,6 +175,7 @@ def __init__( self.base_learning_rate = base_learning_rate self.model = model + self.model.prepare_for_training() self.model = self.distributed.to_device(self.model) self.model.metrics_to_device(self.device) diff --git a/ludwig/utils/model_utils.py b/ludwig/utils/model_utils.py new file mode 100644 index 00000000000..f4a879859c4 --- /dev/null +++ b/ludwig/utils/model_utils.py @@ -0,0 +1,76 @@ +from collections import OrderedDict +from typing import Dict, List, Tuple + +import numpy as np +import torch + +NUMPY_TO_TORCH_DTYPE = { + bool: torch.bool, + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + + +def extract_tensors(model: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]: + """Remove the tensors from a PyTorch model, convert them to NumPy arrays, and return the stripped model and + tensors. + + Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with- + ray-8be751a6944c # noqa + """ + + tensors = [] + for _, module in model.named_modules(): + # Store the tensors as numpy arrays in Python dictionaries + # Delete the same tensors since we no longer need them and we want to reduce memory pressure. + # This ensures that throughout this process, we keep memory nearly linear w.r.t model parameters. + params = OrderedDict() + buffers = OrderedDict() + for name, param in module.named_parameters(recurse=False): + params[name] = torch.clone(param).detach().numpy() + del param + for name, buf in module.named_buffers(recurse=False): + buffers[name] = torch.clone(buf).detach().numpy() + del buf + tensors.append({"params": params, "buffers": buffers}) + + # Strip all tensors and buffers out of the original model. + for _, module in model.named_modules(): + for name in [name for name, _ in module.named_parameters(recurse=False)] + [ + name for name, _ in module.named_buffers(recurse=False) + ]: + setattr(module, name, None) + + return model, tensors + + +def replace_tensors(m: torch.nn.Module, tensors: List[Dict], device: torch.device): + """Restore the tensors that extract_tensors() stripped out of a PyTorch model. This operation is performed in + place. + + Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with- + ray-8be751a6944c # noqa + """ + modules = [module for _, module in m.named_modules()] + for module, tensor_dict in zip(modules, tensors): + # There are separate APIs to set parameters and buffers. 
+ for name, array in tensor_dict["params"].items(): + module.register_parameter( + name, + torch.nn.Parameter(torch.as_tensor(array, device=device, dtype=NUMPY_TO_TORCH_DTYPE.get(array.dtype))), + ) + + for name, array in tensor_dict["buffers"].items(): + module.register_buffer( + name, + torch.as_tensor(array, device=device, dtype=NUMPY_TO_TORCH_DTYPE.get(array.dtype)), + ) diff --git a/ludwig/utils/torch_utils.py b/ludwig/utils/torch_utils.py index e0d50a7deac..10be7c762c8 100644 --- a/ludwig/utils/torch_utils.py +++ b/ludwig/utils/torch_utils.py @@ -183,6 +183,10 @@ def __init__(self): def device(self): return self.device_tensor.device + def prepare_for_training(self): + """This is called from within the Trainer object to do any final instantiation before model training.""" + pass + def losses(self): collected_losses = [] for loss in self._losses.values(): diff --git a/requirements.txt b/requirements.txt index ebc726238a6..2b7af17395c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,9 +10,13 @@ torch>=1.13.0 torchaudio torchtext torchvision -transformers>=4.28.1 +transformers>=4.31.0 +tokenizers>=0.13.3 spacy>=2.3 -PyYAML>=3.12 + +# https://github.com/yaml/pyyaml/issues/601 +PyYAML>=3.12,<6.0.1 + absl-py kaggle requests @@ -24,7 +28,7 @@ marshmallow marshmallow-jsonschema marshmallow-dataclass==8.5.4 tensorboard -torchmetrics<=0.11.4 +torchmetrics>=0.11.0,<=0.11.4 torchinfo filelock psutil==5.9.4 diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index 5e44753c54e..3c5108f556d 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -8,6 +8,7 @@ from ludwig.api import LudwigModel from ludwig.constants import ( ADAPTER, + BACKEND, BASE_MODEL, BATCH_SIZE, EPOCHS, @@ -116,9 +117,10 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu): GENERATION: get_generation_config(), INPUT_FEATURES: input_features, OUTPUT_FEATURES: output_features, + BACKEND: backend, } - model = LudwigModel(config, backend=backend) + model = LudwigModel(config) model.train(dataset=dataset_filename, output_directory=str(tmpdir), skip_save_processed_input=True) preds, _ = model.predict(dataset=dataset_filename, output_directory=str(tmpdir), split="test") @@ -175,9 +177,10 @@ def test_llm_zero_shot_classification(tmpdir, backend, ray_cluster_4cpu): PROMPT: {"task": "This is a review of a restaurant. 
Classify the sentiment."}, INPUT_FEATURES: input_features, OUTPUT_FEATURES: output_features, + BACKEND: backend, } - model = LudwigModel(config, backend=backend) + model = LudwigModel(config) model.train(dataset=df, output_directory=str(tmpdir), skip_save_processed_input=True) prediction_df = pd.DataFrame( @@ -246,6 +249,7 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_ PREPROCESSING: { "split": {TYPE: "fixed"}, }, + BACKEND: {**backend, "cache_dir": str(tmpdir)}, } dataset_path = generate_data( @@ -260,7 +264,7 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_ df["output"] = np.random.choice([1, 2, 3, 4, 5], size=len(df)).astype(str) # ensure labels match the feature config df.to_csv(dataset_path, index=False) - model = LudwigModel(config, backend={**backend, "cache_dir": str(tmpdir)}) + model = LudwigModel(config) model.train(dataset=dataset_path, output_directory=str(tmpdir), skip_save_processed_input=True) # TODO: fix LLM model loading @@ -351,6 +355,7 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat BATCH_SIZE: 8, EPOCHS: 2, }, + BACKEND: backend, } if finetune_strategy is not None: @@ -359,7 +364,7 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat **adapter_args, } - model = LudwigModel(config, backend=backend) + model = LudwigModel(config) model.train(dataset=df, output_directory=str(tmpdir), skip_save_processed_input=False) prediction_df = pd.DataFrame( @@ -408,6 +413,9 @@ def test_lora_wrap_on_init(): } config_obj = ModelConfig.from_dict(config) model = LLM(config_obj) + # We need to explicitly make this call since we now load the adapter + # in the trainer as opposed to the point of LLM model initialization. 
+ model.prepare_for_training() assert not isinstance(model.model, PreTrainedModel) assert isinstance(model.model, PeftModel) diff --git a/tests/ludwig/utils/test_model_utils.py b/tests/ludwig/utils/test_model_utils.py new file mode 100644 index 00000000000..ea6374979d4 --- /dev/null +++ b/tests/ludwig/utils/test_model_utils.py @@ -0,0 +1,61 @@ +import torch + +from ludwig.utils.model_utils import extract_tensors, replace_tensors + +# Define a sample model for testing + + +class SampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.relu = torch.nn.ReLU() + + +def test_extract_tensors(): + # Create a sample model + model = SampleModel() + + # Call extract_tensors function + stripped_model, tensors = extract_tensors(model) + + # Assert that the model and tensors are returned + assert isinstance(stripped_model, torch.nn.Module) + assert isinstance(tensors, list) + + # Assert that the tensors contain the expected keys + for tensor_dict in tensors: + assert "params" in tensor_dict + assert "buffers" in tensor_dict + + # Assert that all model parameters are set to None + for module in stripped_model.modules(): + for name, param in module.named_parameters(recurse=False): + assert param is None + + for name, buf in module.named_buffers(recurse=False): + assert buf is None + + +def test_replace_tensors(): + # Create a sample model + model = SampleModel() + + # Call extract_tensors function to get the tensors + _, tensors = extract_tensors(model) + + # Create a new device for testing + device = torch.device("cpu") + + # Call replace_tensors function + replace_tensors(model, tensors, device) + + # Assert that the tensors are restored + for module, tensor_dict in zip(model.modules(), tensors): + for name, array in tensor_dict["params"].items(): + assert name in module._parameters + assert torch.allclose(module._parameters[name], torch.as_tensor(array, device=device)) + + for name, array in tensor_dict["buffers"].items(): + assert name in module._buffers + assert torch.allclose(module._buffers[name], torch.as_tensor(array, device=device))
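The following is a minimal, self-contained sketch (not part of the patch) of the zero-copy initialization flow this change introduces: extract_tensors() strips a module's weights out as NumPy arrays, ray.put() places the stripped module and its weights in the Plasma object store, and replace_tensors() re-registers the weights in place on the worker without copying them out of shared memory. The TinyModel class and train_worker task are hypothetical stand-ins for Ludwig's LLM and Ray training worker; a local Ray runtime and a Ludwig install that includes the new ludwig/utils/model_utils.py are assumed.

import ray
import torch

from ludwig.utils.model_utils import extract_tensors, replace_tensors


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for the (much larger) LLM module.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)


@ray.remote
def train_worker(model_and_weights):
    # Ray dereferences the ObjectRef argument, handing the worker the stripped
    # module plus its NumPy weights backed by Plasma shared memory.
    model, weights = model_and_weights
    # Load the weights back into the module in place on the current device (CPU).
    replace_tensors(model, weights, torch.device("cpu"))
    return sum(p.numel() for p in model.parameters())


ray.init(ignore_reinit_error=True)

# Driver side: strip the weights into NumPy arrays and put the
# (stripped module, weights) tuple into the Ray object store once.
model_ref = ray.put(extract_tensors(TinyModel()))

# Every worker reads the same shared-memory copy instead of receiving its own.
print(ray.get(train_worker.remote(model_ref)))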
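A second sketch, likewise outside the patch, of the prepare_for_training() hook that defers expensive setup (such as PEFT adapter injection in LLM.initialize_adapter) from the model constructor to the trainer: the model stays lightweight while it travels through the object store, and each worker finishes initialization locally. BaseModule and AdapterModel below are hypothetical; in the patch the no-op hook lives on the base module class in ludwig/utils/torch_utils.py and is invoked from the trainer's __init__.

import torch


class BaseModule(torch.nn.Module):
    def prepare_for_training(self):
        # Called by the trainer for any final instantiation before training; a no-op by default.
        pass


class AdapterModel(BaseModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(16, 16)
        self.adapter = None  # intentionally not built in the constructor

    def prepare_for_training(self):
        # Deferred, trainer-side initialization: the expensive piece is added
        # only on the process where training actually runs.
        self.adapter = torch.nn.Linear(16, 16)


model = AdapterModel()        # cheap to construct and to serialize to workers
assert model.adapter is None
model.prepare_for_training()  # the trainer calls this once per worker
assert isinstance(model.adapter, torch.nn.Module)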