From 98b284b9b1213e9749df7d5b95323dd92cd2f98a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 04:03:55 +0000 Subject: [PATCH 01/59] WIP --- examples/quantization_w4a16/llama3_example.py | 1 + .../modifiers/quantization/gptq/base.py | 235 ++++++--------- .../quantization/gptq/utils/compress.py | 278 ++++++++++++++++++ src/llmcompressor/utils/helpers.py | 45 +++ 4 files changed, 419 insertions(+), 140 deletions(-) create mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/compress.py diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 939991ab6..d587a6199 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,6 +6,7 @@ # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +#MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b472e289e..fd1c42116 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch +from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, disable_quantization, @@ -21,6 +22,7 @@ from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.context import fix_fsdp_module_name +from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, getattr_chain from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -109,11 +111,6 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None - model: Optional[Any] = None - layer_compressors_: Optional[List[Any]] = None - compressible_layers_: Optional[List] = None - quantization_modifier_: Any = None - @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: if not value: @@ -124,6 +121,13 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return value + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.current_layer_index = 0 + self.num_layers = 0 + self.quantization_modifier_ = None def on_initialize_structure(self, state: State, **kwargs): """ @@ -191,20 +195,29 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - modifiable_model = state.model - calibration_dataloader = state.data.calib - + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) if self.sequential_targets is None: - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - self.sequential_targets = get_no_split_params(modifiable_model) + self.sequential_targets = get_no_split_params(state.model) + layers = get_layers(self.sequential_targets, state.model) + self.num_layers = len(layers) + + # add hooks to targets and layers + self.register_hooks(state.model, layers) + + # apply calibration and trigger hooks (hooks are self removing) + self.calibration_forward(state.model, state.data.calib) - self.initialize_compression(modifiable_model, calibration_dataloader) - self.apply_compression(calibration_dataloader) + # freeze quantization state.model.apply(freeze_module_quantization) return True + + def on_end(self): + self.register_hooks(state.model, layers) + self.dummy_forward() ??? def on_finalize(self, state: "State", **kwargs) -> bool: """ @@ -216,121 +229,80 @@ def on_finalize(self, state: "State", **kwargs) -> bool: self.quantization_modifier_.finalize(state, **kwargs) return True - - def compressible_layers(self) -> Dict: - """ - Retrieves the modules corresponding to a list of - compressible layer names - - :precondition: self.model is set and is a torch.nn.Module - :return: dictionary of modules to compress - """ - if not isinstance(self.model, Module): - raise ValueError( - "`self.model` must be a torch.nn.Module to use " - f"the {self.__class__.__qualname__} modifier but got " - f"{type(self.model)} instead" + + def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): + layers = layers.values() + + for name, module in model.named_modules(): + quant_args = getattr_chain(module, "quantization_scheme.weights", None) + if quant_args is not None: + module._gptq_pre_hook = module.register_forward_pre_hook( + partial(self.target_pre_forward, name, quant_args)) + module._gptq_post_hook = module.register_forward_hook( + partial(self.target_post_forward, name, quant_args)) + + if module in layers.values(): + module._gptq_pre_hook = module.register_forward_pre_hook( + partial(self.layer_pre_forward, name)) + module._gptq_post_hook = module.register_forward_hook( + partial(self.layer_post_forward, name)) + + def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dataloader): + all_data = torch.cat([batch for batch in data], dim=0) + with DisableKVCache(model), DisableQuantization(model): + model(all_data) + + def target_pre_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs): + if self.true_sequential: + # compress first so output is from quantized weights + logger.info(f"Compressing {name}...") + gptq_compress( + module, + args, + kwargs, + quant_args, + block_size=self.block_size, + percdamp=self.dampening_frac, ) + + def target_post_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs, output): + if not self.true_sequential: + # compress after so output is from unquantized weights + logger.info(f"Compressing {name}...") + gptq_compress( + module, + args, + kwargs, + quant_args, + block_size=self.block_size, + percdamp=self.dampening_frac, + ) + + def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): + logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") + + def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): + self.remove_hooks(module) - return get_layers(self.sequential_targets, self.model) - - def initialize_compression( - self, - model: Module, - dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, - ): - """ - Setup for GPTQ, initializes the model - and other parameters, also initilializes the - compressible layers of model, and sets the device - - :param model: model to initialize for compression - :param dataloader: calibration data for GPTQ - """ - self.model = model - self.compressible_layers_ = self.compressible_layers() - self.layer_compressors_ = [] - - for idx, (name, layer) in enumerate(self.compressible_layers_.items()): - name = fix_fsdp_module_name(name) - logger.info(f"Preparing {name} for compression") - args = self._pruning_arguments() - comp_cls = self._compression_class() - compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) - - # if running sequentially, allocate all hessians now - if not self.sequential_update: - compressor.pre_compress() - - self.layer_compressors_.append(compressor) - - if self.sequential_update: - first_layer_compressor = self.layer_compressors_[0] - first_layer_compressor.set_early_stop() - - @torch.no_grad() - def apply_compression( - self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None - ) -> Dict: - """ - Run GPTQ on the loaded model, using dataloader as calibration data - - :param dataloader: calibration data for GPTQ - """ - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " f"{len(dataloader)} samples..." - ) - - # quantization scales and zp are already initialized but we do not - # want to calibrate wrt to these - self.model.apply(disable_quantization) - - forward_pass_use_cache = self.model.config.use_cache - self.model.config.use_cache = False - - # in non-sequential mode we run calibration through the full model - # in sequential mode we run calibration up to the first transformer target - intermediates = run_calibration_forward( - self.model, dataloader, mask_padding=True - ) - self.layer_compressors_[0].clear_early_stop() - - # empty cache if not using sequential update - if not self.sequential_update: - del intermediates - gc.collect() - torch.cuda.empty_cache() - - num_layers = len(self.compressible_layers_) - for idx, layer_compressor in enumerate(self.layer_compressors_): - logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") - - if self.sequential_update: - # in sequential mode we run the forward pass for each transformer layer - # one at a time, caching the intermediate outputs between layers - logger.info(f"Calibrating {layer_compressor.name}...") - layer_compressor.pre_compress() - unquantized_outputs = layer_compressor.calibrate_layer(intermediates) - - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() + if not self.true_sequential: + # rerun with (now) quantized weights + output = module(*args, **kwargs) - if self.sequential_update: - quantized_outputs = layer_compressor.calibrate_layer(intermediates) - error = get_output_error(unquantized_outputs, quantized_outputs) - logger.info(f"Mean output error from quantization: {error:.3f}") - intermediates = quantized_outputs - del unquantized_outputs + self.layer_index += 1 + return output - gc.collect() - torch.cuda.empty_cache() + def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): + if hasattr(module, "_gptq_pre_hook"): + module._gptq_pre_hook.remove() + delattr(module, "_gptq_pre_hook") - self.model.config.use_cache = forward_pass_use_cache + if hasattr(module, "_gptq_post_hook"): + module._gptq_post_hook.remove() + delattr(module, "_gptq_post_hook") - # re-enable quantization - self.model.apply(enable_quantization) + if recurse: + for child_module in module.children(): + self.remove_hooks(child_module) def _build_quant_modifier(self): """ @@ -369,20 +341,3 @@ def _build_quant_modifier_from_dict(self, quant_config): allow_experimental=True, **modifier_args, ) - - def _pruning_arguments(self): - """ - Gather the parameters needed for root module compression in a dict - - :return: dict of params for pruning - """ - return { - "blocksize": self.block_size, - "percdamp": self.dampening_frac, - } - - def _compression_class(self): - """ - :return: wrapper class used for root modules of this compression class - """ - return GPTQWrapper diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py b/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py new file mode 100644 index 000000000..05a856072 --- /dev/null +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py @@ -0,0 +1,278 @@ +import torch + + def add_batch(self, inp: torch.Tensor, out: torch.Tensor): + """ + Add a batch of layer input and output data to the Hessian calculation + + :param inp: tensor containing layer input + :param out: tensor containing layer output + """ + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = inp.to(dtype=self.H.dtype) + inp = math.sqrt(2 / self.nsamples) * inp + self.H += inp.matmul(inp.t()) + + def compress( + self, + blocksize: int = 128, + percdamp: float = 0.01, + ): + """ + Run pruning and quantization(if applicable) on the layer up to the target + sparsity value. + + :param blocksize: Number of columns to compress in one pass + :param percdamp: Amount of dampening to apply to H, as a fraction of the + diagonal norm + """ + args_loc = "quantization_scheme.weights" + weight_quant_args = getattr_chain(self.layer, args_loc, None) + if weight_quant_args is None: + logger.debug(f"Skipping unquantized layer {self.name}...") + return + + if is_module_offloaded(self.layer): + self.layer._hf_hook.pre_forward(self.layer) + + strategy = weight_quant_args.strategy + actorder = weight_quant_args.actorder + final_shape = self.layer.weight.shape + final_dtype = self.layer.weight.dtype + W = self.layer.weight.data.clone() + + # standardize shape and dtype + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + elif isinstance(self.layer, transformers.Conv1D): + W.transpose_(0, 1) + W = W.float() + + tick = time.time() + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(self.columns, device=W.device, dtype=torch.int) + // weight_quant_args.group_size + ) + + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, self.H, perm = self._apply_activation_ordering(W, self.H) + self._update_quantization_parameters(weight_quant_args, W) + + # use identity g_idx (invert permutation later) + + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + self._update_quantization_parameters(weight_quant_args, W) + W, self.H, perm = self._apply_activation_ordering(W, self.H) + + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] + + scale = self.layer.weight_scale + zero_point = self.layer.weight_zero_point + + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() + if preserve_zeros + else None + ) + + # mask dead hessian values + dead = torch.diag(self.H) == 0 + self.H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros(self.rows, device=self.dev) + + # compute inverse hessian in place to save memory + damp = percdamp * torch.mean(torch.diag(self.H)) + diag = torch.arange(self.columns, device=self.dev) + self.H[diag, diag] += damp + self.H = torch.linalg.cholesky(self.H) + self.H = torch.cholesky_inverse(self.H) + self.H = torch.linalg.cholesky(self.H, upper=True) + Hinv = self.H + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + + # quantize column + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + self.layer.quantization_scheme.weights, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + weight_quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + group_index = g_idx[column_idx] + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(weight_quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, group_index], + zero_point[:, group_index], + altered_qargs, + ) + else: + raise ValueError( + "Quantization strategy is not supported for GPTQ: " + f"{strategy}" + ) + + # propagate column error + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if preserve_zeros: + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] + else: + W1[:, i:] -= w1_err + Err1[:, i] = err1 + + # propagate block error + W[:, i1:i2] = Q1 + Losses += torch.sum(Losses1, 1) / 2 + + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err + + if "METRIC" in logger._core.levels.keys(): + self._log_metrics(tick, Losses) + + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] + + # only save g_idx if mapping is not identity + update_parameter_data(self.layer, g_idx, "weight_g_idx") + + if isinstance(self.layer, transformers.Conv1D): + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) + + # This is a bit hacky, but FSDP updates only work if we change + # the weight in place, clone() or direct assignment won't work + self.layer.weight -= self.layer.weight + self.layer.weight += W + + if is_module_offloaded(self.layer): + device = get_offloaded_device(self.layer) + update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) + self.layer._hf_hook.post_forward(self.layer, None) + + def free(self): + """ + Free the Hessian memory after the layer is complete + """ + delattr(self, "H") + super().free() + + def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): + """ + Update layer quantization parameters with potentially permuted weight + + :param args: quantization arguments + :param W: weight to calculate quantization parameters from + """ + observer = args.get_observer() + _scale, _zero_point = observer(W, g_idx=None) + update_parameter_data(self.layer, _scale, "weight_scale") + update_parameter_data(self.layer, _zero_point, "weight_zero_point") + + def _apply_activation_ordering( + self, W: torch.Tensor, H: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Permute weight and hessian in order of greatest outupt activations + + :param W: weight to permute + """ + perm = torch.argsort(torch.diag(H), descending=True) + return W[:, perm], H[perm][:, perm], perm + + def _log_metrics(self, start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) + patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + patch.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 266acf973..0305c04df 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -22,8 +22,14 @@ from urllib.parse import urlparse import numpy +import torch from loguru import logger +from compressed_tensors.quantization import ( + disable_quantization, + enable_quantization, +) + __all__ = [ "ALL_TOKEN", "ALL_PRUNABLE_TOKEN", @@ -59,6 +65,7 @@ "is_package_available", "import_from_path", "getattr_chain", + "DisableKVCache", ] @@ -1041,3 +1048,41 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: res = getattr(res, attr_name) return res + + +class DisableKVCache: + def __init__(self, model: torch.nn.Module): + if hasattr(model.config, "use_cache"): + self.config = model.config + + # MllamaConfig + elif hasattr(model.config, "text_config") and hasattr( + model.config.text_config, "use_cache" + ): + self.config = model.config.text_config + + # unknown config structure + else: + raise NotImplementedError( + f"Cannot find `use_cache` for config of type {type(model.config)}" + ) + + self.restore_value = self.config.use_cache + + def __enter__(self): + self.restore_value = self.config.use_cache + self.config.use_cache = False + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.config.use_cache = self.restore_value + + +class DisableQuantization: + def __init__(self, model: torch.nn.Module): + self.model = model + + def __enter__(self): + self.model.apply(disable_quantization) + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.model.apply(enable_quantization) \ No newline at end of file From e3a98cc12c6840fb4836d71d1b77d934512a46f0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 17:08:33 +0000 Subject: [PATCH 02/59] WIP: begin quantize_weight --- .../modifiers/quantization/gptq/base.py | 69 +++---- .../utils/{compress.py => gptq_quantize.py} | 172 +++++++++--------- 2 files changed, 118 insertions(+), 123 deletions(-) rename src/llmcompressor/modifiers/quantization/gptq/utils/{compress.py => gptq_quantize.py} (66%) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index fd1c42116..9e9c31d6e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -215,10 +215,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: return True - def on_end(self): - self.register_hooks(state.model, layers) - self.dummy_forward() ??? - def on_finalize(self, state: "State", **kwargs) -> bool: """ disable the quantization observers used by the OBCQ algorithm @@ -234,49 +230,32 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu layers = layers.values() for name, module in model.named_modules(): - quant_args = getattr_chain(module, "quantization_scheme.weights", None) - if quant_args is not None: - module._gptq_pre_hook = module.register_forward_pre_hook( - partial(self.target_pre_forward, name, quant_args)) - module._gptq_post_hook = module.register_forward_hook( - partial(self.target_post_forward, name, quant_args)) + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook) if module in layers.values(): - module._gptq_pre_hook = module.register_forward_pre_hook( - partial(self.layer_pre_forward, name)) - module._gptq_post_hook = module.register_forward_hook( - partial(self.layer_post_forward, name)) + pre_hook = partial(self.layer_pre_forward, name) + post_hook = partial(self.layer_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook) def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dataloader): all_data = torch.cat([batch for batch in data], dim=0) with DisableKVCache(model), DisableQuantization(model): model(all_data) - def target_pre_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs): + def target_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): if self.true_sequential: # compress first so output is from quantized weights - logger.info(f"Compressing {name}...") - gptq_compress( - module, - args, - kwargs, - quant_args, - block_size=self.block_size, - percdamp=self.dampening_frac, - ) + self.quantize_module(name, module, args) - def target_post_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs, output): + def target_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): if not self.true_sequential: # compress after so output is from unquantized weights - logger.info(f"Compressing {name}...") - gptq_compress( - module, - args, - kwargs, - quant_args, - block_size=self.block_size, - percdamp=self.dampening_frac, - ) + self.quantize_module(name, module, args) def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") @@ -291,6 +270,28 @@ def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, o self.layer_index += 1 return output + def quantize_module(self, name, module, inp): + logger.info(f"Compressing {name}...") + + quant_args = getattr_chain(module, "quantization_scheme.weights") + # with onloaded weight + quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + inp, + quant_args, + block_size=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + ) + + # This is a bit hacky, but FSDP updates only work if we change + # the weight in place, clone() or direct assignment won't work + self.layer.weight -= self.layer.weight + self.layer.weight += lerp(module.weight.data, quantized_weight, self.alpha) + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") + def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): if hasattr(module, "_gptq_pre_hook"): module._gptq_pre_hook.remove() diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py similarity index 66% rename from src/llmcompressor/modifiers/quantization/gptq/utils/compress.py rename to src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 05a856072..66f111ffa 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,85 +1,82 @@ -import torch - - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - - :param inp: tensor containing layer input - :param out: tensor containing layer output - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - inp = inp.to(dtype=self.H.dtype) - inp = math.sqrt(2 / self.nsamples) * inp - self.H += inp.matmul(inp.t()) - - def compress( - self, - blocksize: int = 128, - percdamp: float = 0.01, - ): - """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. - - :param blocksize: Number of columns to compress in one pass - :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm - """ - args_loc = "quantization_scheme.weights" - weight_quant_args = getattr_chain(self.layer, args_loc, None) - if weight_quant_args is None: - logger.debug(f"Skipping unquantized layer {self.name}...") - return - - if is_module_offloaded(self.layer): - self.layer._hf_hook.pre_forward(self.layer) - - strategy = weight_quant_args.strategy - actorder = weight_quant_args.actorder - final_shape = self.layer.weight.shape - final_dtype = self.layer.weight.dtype - W = self.layer.weight.data.clone() - - # standardize shape and dtype - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - elif isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.float() +from typing import Any - tick = time.time() - - if strategy == QuantizationStrategy.GROUP: - # mapping from column index to group index - g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size - ) +import time +import math +import torch +from compressed_tensors.quantization import QuantizationArguments + + +def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: + inp = inp.to(device=device) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + if module_class in (torch.nn.Linear, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + nsamples = inp.shape[0] + + inp = inp.to(dtype=torch.float32) + inp = math.sqrt(2 / nsamples) * inp + return inp.matmul(inp.t()) + + +def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + return H + + +def quantize_weight( + weight: torch.Tensor, + inp: torch.Tensor, + quant_args: QuantizationArguments, + block_size: int = 128, + percdamp: float = 0.01, + module_class = torch.nn.Linear, +) -> Tuple[torch.nn.Parameter, ]: + strategy = quant_args.strategy + actorder = quant_args.actorder + final_shape = weight.shape + final_dtype = weight.dtype + W = weight.data.clone() + + # standardize shape and dtype + if module_class == torch.nn.Conv2d: + W = W.flatten(1) + elif module_class == transformers.Conv1D: + W.transpose_(0, 1) + W = W.to(dtype=torch.float32) + + tick = time.time() + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(self.columns, device=W.device, dtype=torch.int) + // weight_quant_args.group_size + ) - if actorder == ActivationOrdering.GROUP: - # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, self.H, perm = self._apply_activation_ordering(W, self.H) + self._update_quantization_parameters(weight_quant_args, W) - # use identity g_idx (invert permutation later) + # use identity g_idx (invert permutation later) - elif actorder == ActivationOrdering.WEIGHT: - # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + self._update_quantization_parameters(weight_quant_args, W) + W, self.H, perm = self._apply_activation_ordering(W, self.H) - # permute g_idx to maintain identity mapping after unpermutation - g_idx = g_idx[perm] + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] scale = self.layer.weight_scale zero_point = self.layer.weight_zero_point @@ -93,21 +90,18 @@ def compress( else None ) - # mask dead hessian values - dead = torch.diag(self.H) == 0 - self.H[dead, dead] = 1 - W[:, dead] = 0 - Losses = torch.zeros(self.rows, device=self.dev) # compute inverse hessian in place to save memory - damp = percdamp * torch.mean(torch.diag(self.H)) - diag = torch.arange(self.columns, device=self.dev) - self.H[diag, diag] += damp - self.H = torch.linalg.cholesky(self.H) - self.H = torch.cholesky_inverse(self.H) - self.H = torch.linalg.cholesky(self.H, upper=True) - Hinv = self.H + H = compute_hessian(inp, module_class, device=device) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # TODO: check in place + Hinv = invert_hessian(H, percdamp) # See section 3.4 of https://arxiv.org/abs/2203.07259 for i1 in range(0, self.columns, blocksize): From bc9b3bcd889de1557c5fb1868b71a4274d2658f8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 17:20:32 +0000 Subject: [PATCH 03/59] WIP --- .../quantization/gptq/utils/gptq_quantize.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 66f111ffa..f0957b130 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -3,7 +3,8 @@ import time import math import torch -from compressed_tensors.quantization import QuantizationArguments +import transformers +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: @@ -36,8 +37,8 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: def quantize_weight( weight: torch.Tensor, inp: torch.Tensor, - quant_args: QuantizationArguments, - block_size: int = 128, + quant_args: QuantizationArgs, + blocksize: int = 128, percdamp: float = 0.01, module_class = torch.nn.Linear, ) -> Tuple[torch.nn.Parameter, ]: @@ -45,7 +46,10 @@ def quantize_weight( actorder = quant_args.actorder final_shape = weight.shape final_dtype = weight.dtype + num_columns = weight.shape[1] W = weight.data.clone() + + H = compute_hessian(inp, module_class, device=device) # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -56,31 +60,30 @@ def quantize_weight( tick = time.time() + scale, zero_point = compute_scale_zeropoint(W) + if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size + torch.arange(num_columns, device=W.device, dtype=torch.int) + // quant_args.group_size ) if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) + W, H, perm = _apply_activation_ordering(W, H) + scale, zero_point = _update_quantization_parameters(quant_args, W) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) + scale, zero_point = _update_quantization_parameters(quant_args, W) + W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - # sparsity mask sparsity = tensor_sparsity(W) preserve_zeros = sparsity >= SPARSITY_THRESHOLD @@ -90,22 +93,20 @@ def quantize_weight( else None ) - Losses = torch.zeros(self.rows, device=self.dev) - - # compute inverse hessian in place to save memory - H = compute_hessian(inp, module_class, device=device) + Losses = torch.zeros(num_columns, device=weight.device) # mask dead hessian values dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 + # compute inverse hessian in place to save memory # TODO: check in place Hinv = invert_hessian(H, percdamp) # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) count = i2 - i1 W1 = W[:, i1:i2].clone() @@ -128,14 +129,14 @@ def quantize_weight( q, scale, zero_point, - self.layer.quantization_scheme.weights, + quant_args, ) elif strategy == QuantizationStrategy.CHANNEL: q = fake_quantize( q, scale[:, 0], zero_point[:, 0], - weight_quant_args, + quant_args, ) elif strategy == QuantizationStrategy.GROUP: # get the group index for the current column @@ -144,7 +145,7 @@ def quantize_weight( # Since we're only applying quantization to a slice, this # ends up being a channelwise application - altered_qargs = copy(weight_quant_args) + altered_qargs = copy(quant_args) altered_qargs.strategy = QuantizationStrategy.CHANNEL q = fake_quantize( q, From b77c7bf3effbd8b96ee65f5cd2e888a1a9d205a4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 18:55:56 +0000 Subject: [PATCH 04/59] WIP --- .../modifiers/quantization/gptq/base.py | 43 +++++++++---- .../quantization/gptq/utils/gptq_quantize.py | 60 ++++++++++--------- .../quantization/gptq/utils/helpers.py | 12 +++- 3 files changed, 72 insertions(+), 43 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 9e9c31d6e..8bec38fa0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch +import contextlib from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, @@ -18,6 +19,7 @@ from llmcompressor.modifiers.quantization.gptq.utils import ( GPTQWrapper, get_output_error, + gptq_hook ) from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -194,6 +196,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.quantization_modifier_.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + + # after lifecycle refactor, all of this moves to pre_batch # find layers (used for printing even if true_sequential=True) # if no targets are provided, default to the modules that shouldn't be @@ -224,6 +228,8 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) + self.remove_gptq_hooks(state.model) + return True def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): @@ -247,25 +253,28 @@ def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dat with DisableKVCache(model), DisableQuantization(model): model(all_data) + @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): if self.true_sequential: # compress first so output is from quantized weights self.quantize_module(name, module, args) + @gptq_hook def target_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): if not self.true_sequential: # compress after so output is from unquantized weights self.quantize_module(name, module, args) + @gptq_hook def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") + @gptq_hook def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): - self.remove_hooks(module) - if not self.true_sequential: # rerun with (now) quantized weights - output = module(*args, **kwargs) + with self.disable_hooks(): + output = module(*args, **kwargs) self.layer_index += 1 return output @@ -283,16 +292,23 @@ def quantize_module(self, name, module, inp): percdamp=self.dampening_frac, module_class=type(module), ) - - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += lerp(module.weight.data, quantized_weight, self.alpha) - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") - - def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): + + weight = lerp(module.weight.data, quantized_weight, self.alpha) + + update_prefix_dict(self.layer, "weight", weight) + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") + + @contextlib.contextmanager + def disable_hooks(self): + try: + self.hooks_disabled = True + yield + finally: + self.hooks_disabled = False + + def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): if hasattr(module, "_gptq_pre_hook"): module._gptq_pre_hook.remove() delattr(module, "_gptq_pre_hook") @@ -305,6 +321,7 @@ def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): for child_module in module.children(): self.remove_hooks(child_module) + def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index f0957b130..36ab9055f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -182,7 +182,7 @@ def quantize_weight( W[:, i2:] -= w_err if "METRIC" in logger._core.levels.keys(): - self._log_metrics(tick, Losses) + _log_metrics(tick, Losses) if strategy == QuantizationStrategy.GROUP: if actorder == ActivationOrdering.WEIGHT: @@ -213,6 +213,8 @@ def quantize_weight( update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) self.layer._hf_hook.post_forward(self.layer, None) + return W, scale, zero_point, g_idx + def free(self): """ Free the Hessian memory after the layer is complete @@ -243,31 +245,31 @@ def _apply_activation_ordering( perm = torch.argsort(torch.diag(H), descending=True) return W[:, perm], H[perm][:, perm], perm - def _log_metrics(self, start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) +def _log_metrics(self, start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) + patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + patch.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index 58fedc634..f226e41c0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -2,7 +2,7 @@ import torch -__all__ = ["get_output_error"] +__all__ = ["get_output_error", "gptq_hook"] def get_output_error( @@ -49,3 +49,13 @@ def get_output_error( for unq, q in zip(unquantized_outputs, quantized_outputs) ] ) / len(unquantized_outputs) + + +def gptq_hook(func): + def wrapped(self, *args, **kwargs): + if self.hooks_disabled: + return + + func(self, *args, **kwargs) + + return wrapped \ No newline at end of file From 7be5aed7e2996ed4d855ae6f246443784cd43c80 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 20:59:39 +0000 Subject: [PATCH 05/59] wip --- .../modifiers/quantization/gptq/base.py | 9 +++- .../quantization/gptq/utils/gptq_quantize.py | 43 ++++++++++--------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 8bec38fa0..b5da7cce0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -31,6 +31,13 @@ qat_active, ) +from compressed_tensors.utils import ( + get_offloaded_device, + is_module_offloaded, + update_parameter_data, + update_prefix_dict, +) + __all__ = ["GPTQModifier"] @@ -293,7 +300,7 @@ def quantize_module(self, name, module, inp): module_class=type(module), ) - weight = lerp(module.weight.data, quantized_weight, self.alpha) + weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) update_prefix_dict(self.layer, "weight", weight) update_parameter_data(module, scale, "weight_scale") diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 36ab9055f..d5b1efef1 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,10 +1,16 @@ -from typing import Any +from typing import Tuple, Union import time import math import torch import transformers -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering +from copy import copy +from loguru import logger +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering, fake_quantize +from llmcompressor.utils.metric_logging import ( + get_GPU_memory_usage, + get_layer_size_bytes, +) def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: @@ -41,7 +47,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class = torch.nn.Linear, -) -> Tuple[torch.nn.Parameter, ]: +) -> Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -49,7 +55,7 @@ def quantize_weight( num_columns = weight.shape[1] W = weight.data.clone() - H = compute_hessian(inp, module_class, device=device) + H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -60,8 +66,6 @@ def quantize_weight( tick = time.time() - scale, zero_point = compute_scale_zeropoint(W) - if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( @@ -72,18 +76,23 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = _update_quantization_parameters(quant_args, W) + scale, zero_point = compute_scale_zeropoint(W, quant_args) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - scale, zero_point = _update_quantization_parameters(quant_args, W) + scale, zero_point = compute_scale_zeropoint(W, quant_args) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] + else: + scale, zero_point = compute_scale_zeropoint(W, quant_args) + else: + scale, zero_point = compute_scale_zeropoint(W, quant_args) + # sparsity mask sparsity = tensor_sparsity(W) preserve_zeros = sparsity >= SPARSITY_THRESHOLD @@ -184,6 +193,7 @@ def quantize_weight( if "METRIC" in logger._core.levels.keys(): _log_metrics(tick, Losses) + has_gidx = False if strategy == QuantizationStrategy.GROUP: if actorder == ActivationOrdering.WEIGHT: # restore original permutation @@ -197,22 +207,15 @@ def quantize_weight( g_idx = g_idx[invperm] # only save g_idx if mapping is not identity - update_parameter_data(self.layer, g_idx, "weight_g_idx") + has_gidx = True + + if not has_gidx: + g_idx = None - if isinstance(self.layer, transformers.Conv1D): + if module_class == transformers.Conv1D: W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += W - - if is_module_offloaded(self.layer): - device = get_offloaded_device(self.layer) - update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) - self.layer._hf_hook.post_forward(self.layer, None) - return W, scale, zero_point, g_idx def free(self): From e01094fed95d4be087c702f841cf687c76347690 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:01:27 +0000 Subject: [PATCH 06/59] compilable --- .../modifiers/quantization/gptq/base.py | 48 +++++++++++++-- .../quantization/gptq/utils/gptq_quantize.py | 60 ++++++------------- .../quantization/gptq/utils/helpers.py | 49 ++++++++++++++- src/llmcompressor/utils/helpers.py | 17 +++++- 4 files changed, 126 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b5da7cce0..490b46450 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -21,10 +21,13 @@ get_output_error, gptq_hook ) +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight +from llmcompressor.modifiers.quantization.gptq.utils.helpers import LogMetrics from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, getattr_chain +from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain +from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -203,8 +206,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.quantization_modifier_.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - - # after lifecycle refactor, all of this moves to pre_batch # find layers (used for printing even if true_sequential=True) # if no targets are provided, default to the modules that shouldn't be @@ -216,12 +217,14 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.num_layers = len(layers) # add hooks to targets and layers + # after lifecycle refactor, move this to pre_batch self.register_hooks(state.model, layers) # apply calibration and trigger hooks (hooks are self removing) self.calibration_forward(state.model, state.data.calib) # freeze quantization + # after lifecycle refactor, move this to post_batch state.model.apply(freeze_module_quantization) return True @@ -291,7 +294,8 @@ def quantize_module(self, name, module, inp): quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight - quantized_weight, scale, zero_point, g_idx = quantize_weight( + with OnloadModule(module), LogMetrics(module) as logger: + losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, quant_args, @@ -302,10 +306,13 @@ def quantize_module(self, name, module, inp): weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) - update_prefix_dict(self.layer, "weight", weight) + if is_module_offloaded(module): + update_prefix_dict(self.layer, "weight", weight) update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") + + logger.set_losses(losses) @contextlib.contextmanager def disable_hooks(self): @@ -329,6 +336,37 @@ def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): self.remove_hooks(child_module) + def _log_metrics(start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) + patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + patch.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) + + + def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index d5b1efef1..4ecdfe837 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -47,7 +47,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class = torch.nn.Linear, -) -> Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -102,7 +102,7 @@ def quantize_weight( else None ) - Losses = torch.zeros(num_columns, device=weight.device) + losses = torch.zeros(num_columns, device=weight.device) # mask dead hessian values dead = torch.diag(H) == 0 @@ -121,7 +121,7 @@ def quantize_weight( W1 = W[:, i1:i2].clone() Q1 = torch.zeros_like(W1) Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) Hinv1 = Hinv[i1:i2, i1:i2] if preserve_zeros: @@ -170,7 +170,7 @@ def quantize_weight( # propagate column error Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 + losses1[:, i] = (w - q) ** 2 / d**2 err1 = (w - q) / d w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) @@ -182,7 +182,7 @@ def quantize_weight( # propagate block error W[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 + losses += torch.sum(losses1, 1) / 2 w_err = Err1.matmul(Hinv[i1:i2, i2:]) if preserve_zeros: @@ -190,9 +190,6 @@ def quantize_weight( else: W[:, i2:] -= w_err - if "METRIC" in logger._core.levels.keys(): - _log_metrics(tick, Losses) - has_gidx = False if strategy == QuantizationStrategy.GROUP: if actorder == ActivationOrdering.WEIGHT: @@ -216,39 +213,20 @@ def quantize_weight( W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - return W, scale, zero_point, g_idx - - def free(self): - """ - Free the Hessian memory after the layer is complete - """ - delattr(self, "H") - super().free() - - def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): - """ - Update layer quantization parameters with potentially permuted weight - - :param args: quantization arguments - :param W: weight to calculate quantization parameters from - """ - observer = args.get_observer() - _scale, _zero_point = observer(W, g_idx=None) - update_parameter_data(self.layer, _scale, "weight_scale") - update_parameter_data(self.layer, _zero_point, "weight_zero_point") - - def _apply_activation_ordering( - self, W: torch.Tensor, H: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute weight and hessian in order of greatest outupt activations - - :param W: weight to permute - """ - perm = torch.argsort(torch.diag(H), descending=True) - return W[:, perm], H[perm][:, perm], perm - -def _log_metrics(self, start_tick: float, losses: torch.Tensor): + return losses, W, scale, zero_point, g_idx + +def _apply_activation_ordering( + W: torch.Tensor, H: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Permute weight and hessian in order of greatest outupt activations + + :param W: weight to permute + """ + perm = torch.argsort(torch.diag(H), descending=True) + return W[:, perm], H[perm][:, perm], perm + +def _log_metrics(start_tick: float, losses: torch.Tensor): """ Log metrics related to compression algorithm diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index f226e41c0..c15816892 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -1,6 +1,10 @@ from typing import Any, Iterable, List, Tuple, Union +import time import torch +from loguru import logger + +from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes __all__ = ["get_output_error", "gptq_hook"] @@ -58,4 +62,47 @@ def wrapped(self, *args, **kwargs): func(self, *args, **kwargs) - return wrapped \ No newline at end of file + return wrapped + + +class LogMetrics: + def __init__(self, module: torch.nn.Module): + self.module = module + self.start_tick = None + self.losses = None + + def set_losses(self, losses: torch.Tensor): + self.losses = losses + + def __enter__(self): + self.start_tick = time.time() + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + + if self.start_tick is not None: + patch.log("METRIC", "time %.2f" % (time.time() - self.start_tick)) + if self.losses is not None: + patch.log("METRIC", "error %.2f" % torch.sum(self.losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + compressed_size = get_layer_size_bytes(self.module) + patch.log("METRIC", f"Compressed layer size: {compressed_size} MB") diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 0305c04df..b46685110 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -29,6 +29,7 @@ disable_quantization, enable_quantization, ) +from compressed_tensors import is_module_offloaded __all__ = [ "ALL_TOKEN", @@ -1085,4 +1086,18 @@ def __enter__(self): self.model.apply(disable_quantization) def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.model.apply(enable_quantization) \ No newline at end of file + self.model.apply(enable_quantization) + + +class OnloadModule: + def __init__(self, module: torch.nn.Module): + self.module = module + self.is_module_offloaded = is_module_offloaded(self.module) + + def __enter__(self): + if self.is_module_offloaded: + self.module._hf_hook.pre_forward(self.module) + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + if self.is_module_offloaded: + self.module._hf_hook.post_forward(self.module, None) \ No newline at end of file From ad9f5a8d6100027ab045b0d44cc068d59c041c33 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:02:00 +0000 Subject: [PATCH 07/59] compilable --- .../modifiers/quantization/gptq/base.py | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 490b46450..cfcdfc529 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -335,38 +335,6 @@ def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): for child_module in module.children(): self.remove_hooks(child_module) - - def _log_metrics(start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) - - - def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, From e4ee0af5c3c32be52329f39d2f51f83a857cb656 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:52:13 +0000 Subject: [PATCH 08/59] wip --- .../modifiers/quantization/gptq/base.py | 73 ++-- .../quantization/gptq/utils/__init__.py | 2 +- .../quantization/gptq/utils/gptq_quantize.py | 29 -- .../quantization/gptq/utils/gptq_wrapper.py | 341 ------------------ .../quantization/gptq/utils/helpers.py | 2 +- 5 files changed, 50 insertions(+), 397 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index cfcdfc529..804359957 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -6,28 +6,22 @@ from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, - disable_quantization, - enable_quantization, freeze_module_quantization, ) from loguru import logger from pydantic import Field, field_validator -from torch.nn import Module from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory from llmcompressor.modifiers.quantization.gptq.utils import ( - GPTQWrapper, get_output_error, gptq_hook ) from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import LogMetrics -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.utils.fsdp.context import fix_fsdp_module_name from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain -from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -123,6 +117,11 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None + _layer_index: int = 0 + _num_layers: int = 0 + _hooks_disabled: bool = False + quantization_modifier_: Optional[QuantizationModifier] = None + @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: if not value: @@ -137,8 +136,8 @@ def validate_sequential_update(cls, value: bool) -> bool: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.current_layer_index = 0 - self.num_layers = 0 + self._layer_index = 0 + self._num_layers = 0 self.quantization_modifier_ = None def on_initialize_structure(self, state: State, **kwargs): @@ -214,7 +213,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if self.sequential_targets is None: self.sequential_targets = get_no_split_params(state.model) layers = get_layers(self.sequential_targets, state.model) - self.num_layers = len(layers) + self._num_layers = len(layers) # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch @@ -243,8 +242,6 @@ def on_finalize(self, state: "State", **kwargs) -> bool: return True def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): - layers = layers.values() - for name, module in model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) @@ -256,37 +253,63 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + + def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + import torch.nn.functional as F + + accumulated_data = {} # Dictionary to accumulate samples per key - def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dataloader): - all_data = torch.cat([batch for batch in data], dim=0) + def pad_tensor(tensor, max_len): + """Pads a tensor to the specified max_len along the second dimension (sequence length).""" + pad_size = max_len - tensor.size(1) # Calculate the padding size + return F.pad(tensor, (0, pad_size), value=0) # Pad on the right with zeros + + for batch in dataloader: + for key, value in batch.items(): + if key not in accumulated_data: + accumulated_data[key] = [] + accumulated_data[key].append(value) # Accumulate values for each key + + # Find maximum length for each key across all samples to ensure matching shapes + max_lengths = {} + for key, tensors in accumulated_data.items(): + max_lengths[key] = max([tensor.size(1) for tensor in tensors]) # Assuming the second dimension is the sequence length + + # Pad and concatenate for each key + concatenated_batch = { + key: torch.cat([pad_tensor(tensor, max_lengths[key]) for tensor in accumulated_data[key]], dim=0) + for key in accumulated_data + } + with DisableKVCache(model), DisableQuantization(model): - model(all_data) + model(**concatenated_batch) @gptq_hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): + def target_pre_forward(self, name: str, module: torch.nn.Module, args): if self.true_sequential: # compress first so output is from quantized weights self.quantize_module(name, module, args) @gptq_hook - def target_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): + def target_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any): if not self.true_sequential: # compress after so output is from unquantized weights self.quantize_module(name, module, args) @gptq_hook - def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): - logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") + def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): + logger.info(f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====") + breakpoint() @gptq_hook - def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): + def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], output: Any): if not self.true_sequential: # rerun with (now) quantized weights with self.disable_hooks(): - output = module(*args, **kwargs) + output = module(args, **kwargs) - self.layer_index += 1 + self._layer_index += 1 return output def quantize_module(self, name, module, inp): @@ -317,10 +340,10 @@ def quantize_module(self, name, module, inp): @contextlib.contextmanager def disable_hooks(self): try: - self.hooks_disabled = True + self._hooks_disabled = True yield finally: - self.hooks_disabled = False + self._hooks_disabled = False def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): if hasattr(module, "_gptq_pre_hook"): diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py index a8673dfc2..5703ced46 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa -from .gptq_wrapper import * +from .gptq_quantize import * from .helpers import * diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 4ecdfe837..512741888 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -225,32 +225,3 @@ def _apply_activation_ordering( """ perm = torch.argsort(torch.diag(H), descending=True) return W[:, perm], H[perm][:, perm], perm - -def _log_metrics(start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py deleted file mode 100644 index d53b942eb..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ /dev/null @@ -1,341 +0,0 @@ -import time -from typing import Tuple - -from compressed_tensors.quantization import ( - ActivationOrdering, - QuantizationArgs, - QuantizationStrategy, -) -from compressed_tensors.quantization.lifecycle.forward import fake_quantize - -from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD -from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper -from llmcompressor.pytorch.utils.helpers import tensor_sparsity -from llmcompressor.utils import getattr_chain -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) - -try: - import transformers -except ImportError as err: - transformers = None - transformers_err = err - -import math -from copy import copy - -import torch -import torch.nn as nn -from compressed_tensors.utils import ( - get_offloaded_device, - is_module_offloaded, - update_parameter_data, - update_prefix_dict, -) -from loguru import logger - -__all__ = ["GPTQWrapper"] - - -class GPTQWrapper(ModuleCompressionWrapper): - """ - Runs GPTQ on a single module that contains no sub-modules - - Lifecycle: - - add_batch - - compress - - free - - :param name: name of module to run compression on - :param layer: module to run compression on - """ - - def __init__(self, name, layer): - super().__init__(name=name, layer=layer) - - # for Hessian calculation - self.register_buffer( - "H", - torch.zeros( - (self.columns, self.columns), device=self.dev, dtype=torch.float32 - ), - ) - - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - - :param inp: tensor containing layer input - :param out: tensor containing layer output - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - inp = inp.to(dtype=self.H.dtype) - inp = math.sqrt(2 / self.nsamples) * inp - self.H += inp.matmul(inp.t()) - - def compress( - self, - blocksize: int = 128, - percdamp: float = 0.01, - ): - """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. - - :param blocksize: Number of columns to compress in one pass - :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm - """ - args_loc = "quantization_scheme.weights" - weight_quant_args = getattr_chain(self.layer, args_loc, None) - if weight_quant_args is None: - logger.debug(f"Skipping unquantized layer {self.name}...") - return - - if is_module_offloaded(self.layer): - self.layer._hf_hook.pre_forward(self.layer) - - strategy = weight_quant_args.strategy - actorder = weight_quant_args.actorder - final_shape = self.layer.weight.shape - final_dtype = self.layer.weight.dtype - W = self.layer.weight.data.clone() - - # standardize shape and dtype - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - elif isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.float() - - tick = time.time() - - if strategy == QuantizationStrategy.GROUP: - # mapping from column index to group index - g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size - ) - - if actorder == ActivationOrdering.GROUP: - # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) - - # use identity g_idx (invert permutation later) - - elif actorder == ActivationOrdering.WEIGHT: - # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) - - # permute g_idx to maintain identity mapping after unpermutation - g_idx = g_idx[perm] - - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - - # sparsity mask - sparsity = tensor_sparsity(W) - preserve_zeros = sparsity >= SPARSITY_THRESHOLD - W_nz_mask = ( - (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() - if preserve_zeros - else None - ) - - # mask dead hessian values - dead = torch.diag(self.H) == 0 - self.H[dead, dead] = 1 - W[:, dead] = 0 - - Losses = torch.zeros(self.rows, device=self.dev) - - # compute inverse hessian in place to save memory - damp = percdamp * torch.mean(torch.diag(self.H)) - diag = torch.arange(self.columns, device=self.dev) - self.H[diag, diag] += damp - self.H = torch.linalg.cholesky(self.H) - self.H = torch.cholesky_inverse(self.H) - self.H = torch.linalg.cholesky(self.H, upper=True) - Hinv = self.H - - # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - if preserve_zeros: - W1_nz_mask = W_nz_mask[:, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = w.clone() - - # quantize column - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - self.layer.quantization_scheme.weights, - ) - elif strategy == QuantizationStrategy.CHANNEL: - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - weight_quant_args, - ) - elif strategy == QuantizationStrategy.GROUP: - # get the group index for the current column - column_idx = i1 + i - group_index = g_idx[column_idx] - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(weight_quant_args) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, group_index], - zero_point[:, group_index], - altered_qargs, - ) - else: - raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" - ) - - # propagate column error - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if preserve_zeros: - W1[:, i:] -= w1_err * W1_nz_mask[:, i:] - else: - W1[:, i:] -= w1_err - Err1[:, i] = err1 - - # propagate block error - W[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 - - w_err = Err1.matmul(Hinv[i1:i2, i2:]) - if preserve_zeros: - W[:, i2:] -= w_err * W_nz_mask[:, i2:] - else: - W[:, i2:] -= w_err - - if "METRIC" in logger._core.levels.keys(): - self._log_metrics(tick, Losses) - - if strategy == QuantizationStrategy.GROUP: - if actorder == ActivationOrdering.WEIGHT: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - - elif actorder == ActivationOrdering.GROUP: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - g_idx = g_idx[invperm] - - # only save g_idx if mapping is not identity - update_parameter_data(self.layer, g_idx, "weight_g_idx") - - if isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.reshape(final_shape).to(final_dtype) - - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += W - - if is_module_offloaded(self.layer): - device = get_offloaded_device(self.layer) - update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) - self.layer._hf_hook.post_forward(self.layer, None) - - def free(self): - """ - Free the Hessian memory after the layer is complete - """ - delattr(self, "H") - super().free() - - def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): - """ - Update layer quantization parameters with potentially permuted weight - - :param args: quantization arguments - :param W: weight to calculate quantization parameters from - """ - observer = args.get_observer() - _scale, _zero_point = observer(W, g_idx=None) - update_parameter_data(self.layer, _scale, "weight_scale") - update_parameter_data(self.layer, _zero_point, "weight_zero_point") - - def _apply_activation_ordering( - self, W: torch.Tensor, H: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute weight and hessian in order of greatest outupt activations - - :param W: weight to permute - """ - perm = torch.argsort(torch.diag(H), descending=True) - return W[:, perm], H[perm][:, perm], perm - - def _log_metrics(self, start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index c15816892..413f5eaca 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -57,7 +57,7 @@ def get_output_error( def gptq_hook(func): def wrapped(self, *args, **kwargs): - if self.hooks_disabled: + if self._hooks_disabled: return func(self, *args, **kwargs) From d9ba539739f4e4cd3acfdde2df36f58e66a6bfc7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:52:34 +0000 Subject: [PATCH 09/59] add example --- examples/quantization_w4a16/llama3_example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index d587a6199..01d9dba8c 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,3 +1,4 @@ +import torch from datasets import load_dataset from transformers import AutoTokenizer @@ -5,8 +6,9 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" #MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, From 83a5762c932dc69d6bc9aa714ff39f5f1149b2e1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 23:39:07 +0000 Subject: [PATCH 10/59] wip --- .../modifiers/quantization/gptq/base.py | 24 +- .../quantization/gptq/utils/gptq_quantize.py | 238 +++++++++--------- .../quantization/gptq/utils/helpers.py | 7 +- 3 files changed, 139 insertions(+), 130 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 804359957..3d70b2d40 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -18,7 +18,7 @@ gptq_hook ) from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight -from llmcompressor.modifiers.quantization.gptq.utils.helpers import LogMetrics +from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.utils.fsdp.context import fix_fsdp_module_name from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain @@ -106,6 +106,7 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True + true_sequential: bool = False targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -256,12 +257,12 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + """ import torch.nn.functional as F accumulated_data = {} # Dictionary to accumulate samples per key def pad_tensor(tensor, max_len): - """Pads a tensor to the specified max_len along the second dimension (sequence length).""" pad_size = max_len - tensor.size(1) # Calculate the padding size return F.pad(tensor, (0, pad_size), value=0) # Pad on the right with zeros @@ -281,9 +282,12 @@ def pad_tensor(tensor, max_len): key: torch.cat([pad_tensor(tensor, max_lengths[key]) for tensor in accumulated_data[key]], dim=0) for key in accumulated_data } + """ + + batch = next(iter(dataloader)) with DisableKVCache(model), DisableQuantization(model): - model(**concatenated_batch) + model(**batch) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -300,7 +304,6 @@ def target_post_forward(self, name: str, module: torch.nn.Module, args: torch.Te @gptq_hook def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): logger.info(f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====") - breakpoint() @gptq_hook def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], output: Any): @@ -312,22 +315,25 @@ def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Ten self._layer_index += 1 return output - def quantize_module(self, name, module, inp): + def quantize_module(self, name, module, args): logger.info(f"Compressing {name}...") + inp = args[0] # Assume that first argument is input (true for most Module types) quant_args = getattr_chain(module, "quantization_scheme.weights") + # with onloaded weight - with OnloadModule(module), LogMetrics(module) as logger: + with OnloadModule(module), MetricsLogger(module) as metrics_logger: losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, quant_args, - block_size=self.block_size, + blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), ) - weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) + #weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) + weight = quantized_weight if is_module_offloaded(module): update_prefix_dict(self.layer, "weight", weight) @@ -335,7 +341,7 @@ def quantize_module(self, name, module, inp): update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") - logger.set_losses(losses) + metrics_logger.set_losses(losses) @contextlib.contextmanager def disable_hooks(self): diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 512741888..2f0d3120f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -7,10 +7,9 @@ from copy import copy from loguru import logger from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering, fake_quantize -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) +from compressed_tensors.quantization.observers import MovingAverageMinMaxObserver +from llmcompressor.pytorch.utils.helpers import tensor_sparsity +from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: @@ -40,6 +39,10 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H +def compute_scale_zeropoint(W: torch.Tensor, quant_args: QuantizationArgs) -> Tuple[torch.Tensor, torch.Tensor]: + return MovingAverageMinMaxObserver(quant_args)(W) + + def quantize_weight( weight: torch.Tensor, inp: torch.Tensor, @@ -52,7 +55,6 @@ def quantize_weight( actorder = quant_args.actorder final_shape = weight.shape final_dtype = weight.dtype - num_columns = weight.shape[1] W = weight.data.clone() H = compute_hessian(inp, module_class, device=weight.device) @@ -63,8 +65,7 @@ def quantize_weight( elif module_class == transformers.Conv1D: W.transpose_(0, 1) W = W.to(dtype=torch.float32) - - tick = time.time() + num_columns = W.shape[0] if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index @@ -93,127 +94,128 @@ def quantize_weight( else: scale, zero_point = compute_scale_zeropoint(W, quant_args) - # sparsity mask - sparsity = tensor_sparsity(W) - preserve_zeros = sparsity >= SPARSITY_THRESHOLD - W_nz_mask = ( - (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() - if preserve_zeros - else None - ) - - losses = torch.zeros(num_columns, device=weight.device) - - # mask dead hessian values - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - W[:, dead] = 0 - - # compute inverse hessian in place to save memory - # TODO: check in place - Hinv = invert_hessian(H, percdamp) - - # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, num_columns, blocksize): - i2 = min(i1 + blocksize, num_columns) - count = i2 - i1 + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() + if preserve_zeros + else None + ) + + losses = torch.zeros(num_columns, device=weight.device) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # compute inverse hessian in place to save memory + # TODO: check in place + Hinv = invert_hessian(H, percdamp) + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) + count = i2 - i1 + print((i1, i2, num_columns)) + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + + # quantize column + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + quant_args, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + group_index = g_idx[column_idx] + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, group_index], + zero_point[:, group_index], + altered_qargs, + ) + else: + raise ValueError( + "Quantization strategy is not supported for GPTQ: " + f"{strategy}" + ) - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] + # propagate column error + Q1[:, i] = q + losses1[:, i] = (w - q) ** 2 / d**2 + err1 = (w - q) / d + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) if preserve_zeros: - W1_nz_mask = W_nz_mask[:, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = w.clone() - - # quantize column - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - quant_args, - ) - elif strategy == QuantizationStrategy.CHANNEL: - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - quant_args, - ) - elif strategy == QuantizationStrategy.GROUP: - # get the group index for the current column - column_idx = i1 + i - group_index = g_idx[column_idx] - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(quant_args) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, group_index], - zero_point[:, group_index], - altered_qargs, - ) - else: - raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" - ) - - # propagate column error - Q1[:, i] = q - losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if preserve_zeros: - W1[:, i:] -= w1_err * W1_nz_mask[:, i:] - else: - W1[:, i:] -= w1_err - Err1[:, i] = err1 - - # propagate block error - W[:, i1:i2] = Q1 - losses += torch.sum(losses1, 1) / 2 - - w_err = Err1.matmul(Hinv[i1:i2, i2:]) - if preserve_zeros: - W[:, i2:] -= w_err * W_nz_mask[:, i2:] + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] else: - W[:, i2:] -= w_err + W1[:, i:] -= w1_err + Err1[:, i] = err1 - has_gidx = False - if strategy == QuantizationStrategy.GROUP: - if actorder == ActivationOrdering.WEIGHT: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] + # propagate block error + W[:, i1:i2] = Q1 + losses += torch.sum(losses1, 1) / 2 - elif actorder == ActivationOrdering.GROUP: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - g_idx = g_idx[invperm] + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err - # only save g_idx if mapping is not identity - has_gidx = True + has_gidx = False + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] - if not has_gidx: - g_idx = None + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] - if module_class == transformers.Conv1D: - W.transpose_(0, 1) - W = W.reshape(final_shape).to(final_dtype) + # only save g_idx if mapping is not identity + has_gidx = True + + if not has_gidx: + g_idx = None + + if module_class == transformers.Conv1D: + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) - return losses, W, scale, zero_point, g_idx + return losses, W, scale, zero_point, g_idx def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index 413f5eaca..6ebb1dc7a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -6,7 +6,7 @@ from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes -__all__ = ["get_output_error", "gptq_hook"] +__all__ = ["get_output_error", "gptq_hook", "MetricsLogger"] def get_output_error( @@ -65,7 +65,7 @@ def wrapped(self, *args, **kwargs): return wrapped -class LogMetrics: +class MetricsLogger: def __init__(self, module: torch.nn.Module): self.module = module self.start_tick = None @@ -74,8 +74,9 @@ def __init__(self, module: torch.nn.Module): def set_losses(self, losses: torch.Tensor): self.losses = losses - def __enter__(self): + def __enter__(self) -> "MetricsLogger": self.start_tick = time.time() + return self def __exit__(self, _exc_type, _exc_val, _exc_tb): """ From 7f49ab40c245bea5a8350b479856dd5ced9fb573 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 23:50:33 +0000 Subject: [PATCH 11/59] runnable --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 3d70b2d40..7b56f0e05 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -310,7 +310,7 @@ def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Ten if not self.true_sequential: # rerun with (now) quantized weights with self.disable_hooks(): - output = module(args, **kwargs) + output = module(*args, **kwargs) self._layer_index += 1 return output diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 2f0d3120f..ebe657ae4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -65,7 +65,8 @@ def quantize_weight( elif module_class == transformers.Conv1D: W.transpose_(0, 1) W = W.to(dtype=torch.float32) - num_columns = W.shape[0] + num_rows = W.shape[0] + num_columns = W.shape[1] if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index @@ -103,7 +104,7 @@ def quantize_weight( else None ) - losses = torch.zeros(num_columns, device=weight.device) + losses = torch.zeros(num_rows, device=weight.device) # mask dead hessian values dead = torch.diag(H) == 0 @@ -118,7 +119,6 @@ def quantize_weight( for i1 in range(0, num_columns, blocksize): i2 = min(i1 + blocksize, num_columns) count = i2 - i1 - print((i1, i2, num_columns)) W1 = W[:, i1:i2].clone() Q1 = torch.zeros_like(W1) @@ -166,8 +166,7 @@ def quantize_weight( ) else: raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" + f"Quantization strategy is not supported for GPTQ: {strategy}" ) # propagate column error From ac0d9266b1bb2af468dd8950646a0d1f1773ea41 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 20:30:07 +0000 Subject: [PATCH 12/59] batching --- .../modifiers/quantization/gptq/base.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7b56f0e05..1054dc436 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -20,6 +20,7 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data from llmcompressor.utils.fsdp.context import fix_fsdp_module_name from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain from llmcompressor.utils.pytorch.module import ( @@ -257,37 +258,28 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): - """ import torch.nn.functional as F - - accumulated_data = {} # Dictionary to accumulate samples per key - - def pad_tensor(tensor, max_len): - pad_size = max_len - tensor.size(1) # Calculate the padding size - return F.pad(tensor, (0, pad_size), value=0) # Pad on the right with zeros - - for batch in dataloader: - for key, value in batch.items(): - if key not in accumulated_data: - accumulated_data[key] = [] - accumulated_data[key].append(value) # Accumulate values for each key - - # Find maximum length for each key across all samples to ensure matching shapes - max_lengths = {} - for key, tensors in accumulated_data.items(): - max_lengths[key] = max([tensor.size(1) for tensor in tensors]) # Assuming the second dimension is the sequence length - - # Pad and concatenate for each key - concatenated_batch = { - key: torch.cat([pad_tensor(tensor, max_lengths[key]) for tensor in accumulated_data[key]], dim=0) - for key in accumulated_data - } - """ - - batch = next(iter(dataloader)) + from torch.nn.utils.rnn import pad_sequence + + dataset = dataloader.dataset + def collate_fn(batch): + # Extract input_ids and attention_mask from the batch + input_ids = [torch.tensor(item['input_ids']) for item in batch] + attention_masks = [torch.tensor(item['attention_mask']) for item in batch] + + # Pad sequences in the batch + padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0) + + return { + 'input_ids': padded_input_ids, + 'attention_mask': padded_attention_masks + } + dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=True, collate_fn=collate_fn) + data = next(iter(dataloader)) with DisableKVCache(model), DisableQuantization(model): - model(**batch) + model(**data) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -362,7 +354,7 @@ def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): if recurse: for child_module in module.children(): - self.remove_hooks(child_module) + self.remove_gptq_hooks(child_module) def _build_quant_modifier(self): """ From 63049739f2e0fbb723dd9000e057a4000487c617 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 22:12:06 +0000 Subject: [PATCH 13/59] calibration forward context --- .../modifiers/quantization/gptq/base.py | 36 +++++++------ .../modifiers/utils/layer_compressor.py | 17 +++++++ src/llmcompressor/utils/helpers.py | 51 ++++++++++++------- 3 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1054dc436..342e194cd 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch +from torch.nn.utils.rnn import pad_sequence import contextlib from functools import partial from compressed_tensors.quantization import ( @@ -20,9 +21,10 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain +from llmcompressor.utils.helpers import calibration_forward_context, align_module, getattr_chain from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -258,28 +260,32 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): - import torch.nn.functional as F - from torch.nn.utils.rnn import pad_sequence - dataset = dataloader.dataset def collate_fn(batch): - # Extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item['input_ids']) for item in batch] - attention_masks = [torch.tensor(item['attention_mask']) for item in batch] + # extract input_ids and attention_mask from the batch + input_ids = [torch.tensor(item["input_ids"]) for item in batch] + attention_masks = [torch.tensor(item["attention_mask"]) for item in batch] - # Pad sequences in the batch + # pad sequences in the batch padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0) return { - 'input_ids': padded_input_ids, - 'attention_mask': padded_attention_masks + "input_ids": padded_input_ids, + "attention_mask": padded_attention_masks } - dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=True, collate_fn=collate_fn) - data = next(iter(dataloader)) - with DisableKVCache(model), DisableQuantization(model): - model(**data) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=len(dataset), + shuffle=True, + collate_fn=collate_fn, + pin_memory=True + ) + + calibration_data = next(iter(dataloader)) + with calibration_forward_context(model): + model(**calibration_data) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -314,7 +320,7 @@ def quantize_module(self, name, module, args): quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight - with OnloadModule(module), MetricsLogger(module) as metrics_logger: + with align_module(module), MetricsLogger(module) as metrics_logger: losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 3dd3caa7e..714d328df 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -20,6 +20,23 @@ __all__ = ["LayerCompressor"] +class LayerCompressorMixin: + def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): + return + for name, module in model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook) + + if module in layers.values(): + pre_hook = partial(self.layer_pre_forward, name) + post_hook = partial(self.layer_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + + class LayerCompressor: """ Runs weight sparisification on a single layer using calibration data inputs. The diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index b46685110..db4846b7b 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -15,10 +15,11 @@ import sys import tarfile import warnings +import contextlib from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Optional from urllib.parse import urlparse import numpy @@ -1078,26 +1079,40 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): self.config.use_cache = self.restore_value -class DisableQuantization: - def __init__(self, model: torch.nn.Module): - self.model = model +@contextlib.contextmanager +def DisableQuantization(model: torch.nn.Module): + model.apply(disable_quantization) + yield + model.apply(enable_quantization) - def __enter__(self): - self.model.apply(disable_quantization) - def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.model.apply(enable_quantization) +def calibration_forward_context(model: torch.nn.Module): + torch.eval() + with ( + torch.no_grad(), + DisableKVCache(model), + DisableQuantization(model), + ): + yield -class OnloadModule: - def __init__(self, module: torch.nn.Module): - self.module = module - self.is_module_offloaded = is_module_offloaded(self.module) - def __enter__(self): - if self.is_module_offloaded: - self.module._hf_hook.pre_forward(self.module) +@contextlib.contextmanager +def align_module(module: torch.nn.Module, device: Optional[torch.device] = None): + """ + Move an offloaded module's parameters to device or module execution device - def __exit__(self, _exc_type, _exc_val, _exc_tb): - if self.is_module_offloaded: - self.module._hf_hook.post_forward(self.module, None) \ No newline at end of file + :param module: module with parameters to align + :param device: optional device to move parameters to, if None is provided then + module execution device will be used + """ + if device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = device + + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, torch.tensor([])) + + if device is not None: + module._hf_hook.execution_device = original_device \ No newline at end of file From 868a480d9c3ae076dec8861bbcb03bc03b6b799b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 22:31:08 +0000 Subject: [PATCH 14/59] fix stuff --- examples/quantization_w4a16/llama3_example.py | 5 +++-- src/llmcompressor/modifiers/quantization/gptq/base.py | 3 +-- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- src/llmcompressor/utils/helpers.py | 3 ++- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 01d9dba8c..56aef6b7a 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -23,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 512 // 4 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -44,10 +44,11 @@ def preprocess(example): # Tokenize inputs. +tokenizer.add_special_tokens({'pad_token': '[PAD]'}) def tokenize(sample): return tokenizer( sample["text"], - padding=False, + padding=True, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False, diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 342e194cd..77cf3c605 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -283,9 +283,8 @@ def collate_fn(batch): pin_memory=True ) - calibration_data = next(iter(dataloader)) with calibration_forward_context(model): - model(**calibration_data) + run_calibration_forward(model, dataloader, mask_padding=True) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9003ff22d..20abaf376 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -102,7 +102,7 @@ def run_calibration_forward( # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass - torch.cuda.empty_cache() + #torch.cuda.empty_cache() return intermediates diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index db4846b7b..14c724320 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1086,8 +1086,9 @@ def DisableQuantization(model: torch.nn.Module): model.apply(enable_quantization) +@contextlib.contextmanager def calibration_forward_context(model: torch.nn.Module): - torch.eval() + model.eval() with ( torch.no_grad(), From 86c8a06dae722f289360737600d67714359ac797 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 22:44:45 +0000 Subject: [PATCH 15/59] wip --- examples/quantization_w4a16/llama3_example.py | 4 ++-- src/llmcompressor/modifiers/quantization/gptq/base.py | 1 + .../modifiers/quantization/gptq/utils/gptq_quantize.py | 9 +++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 56aef6b7a..fbb1f2e2c 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -24,7 +24,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. NUM_CALIBRATION_SAMPLES = 512 // 4 -MAX_SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 2048 // 2 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) @@ -59,7 +59,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01) # Apply algorithms. oneshot( diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 77cf3c605..5cfd036a0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -283,6 +283,7 @@ def collate_fn(batch): pin_memory=True ) + breakpoint() with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index ebe657ae4..203dac5f0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -12,19 +12,24 @@ from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +GPTQ_PRECISION = torch.float32 + + def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: inp = inp.to(device=device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) + breakpoint() if module_class in (torch.nn.Linear, transformers.Conv1D): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() nsamples = inp.shape[0] + breakpoint() - inp = inp.to(dtype=torch.float32) + inp = inp.to(dtype=GPTQ_PRECISION) inp = math.sqrt(2 / nsamples) * inp return inp.matmul(inp.t()) @@ -64,7 +69,7 @@ def quantize_weight( W = W.flatten(1) elif module_class == transformers.Conv1D: W.transpose_(0, 1) - W = W.to(dtype=torch.float32) + W = W.to(dtype=GPTQ_PRECISION) num_rows = W.shape[0] num_columns = W.shape[1] From 130517354f465253d1669a7ac70b2dbb9c85a905 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 23:01:07 +0000 Subject: [PATCH 16/59] use hooks list --- examples/quantization_w4a16/llama3_example.py | 2 +- .../modifiers/quantization/gptq/base.py | 28 +++++++------------ .../quantization/gptq/utils/gptq_quantize.py | 7 ++--- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index fbb1f2e2c..2568c59ed 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -23,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 // 4 +NUM_CALIBRATION_SAMPLES = 512 // 6 MAX_SEQUENCE_LENGTH = 2048 // 2 # Load dataset and preprocess. diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 5cfd036a0..216d9e6cb 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -3,6 +3,7 @@ import torch from torch.nn.utils.rnn import pad_sequence +from torch.utils.hooks import RemovableHandle import contextlib from functools import partial from compressed_tensors.quantization import ( @@ -125,6 +126,7 @@ class GPTQModifier(Modifier): _num_layers: int = 0 _hooks_disabled: bool = False quantization_modifier_: Optional[QuantizationModifier] = None + _hooks: List[RemovableHandle] = [] @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -241,7 +243,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) - self.remove_gptq_hooks(state.model) + self.remove_gptq_hooks() return True @@ -250,14 +252,14 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append(module.register_forward_hook(post_hook)) if module in layers.values(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append(module.register_forward_hook(post_hook, with_kwargs=True)) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): dataset = dataloader.dataset @@ -283,7 +285,6 @@ def collate_fn(batch): pin_memory=True ) - breakpoint() with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) @@ -349,18 +350,9 @@ def disable_hooks(self): finally: self._hooks_disabled = False - def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): - if hasattr(module, "_gptq_pre_hook"): - module._gptq_pre_hook.remove() - delattr(module, "_gptq_pre_hook") - - if hasattr(module, "_gptq_post_hook"): - module._gptq_post_hook.remove() - delattr(module, "_gptq_post_hook") - - if recurse: - for child_module in module.children(): - self.remove_gptq_hooks(child_module) + def remove_gptq_hooks(self): + for hook in self._hooks: + hook.remove() def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 203dac5f0..8e87f3ee0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -20,15 +20,14 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: if len(inp.shape) == 2: inp = inp.unsqueeze(0) - breakpoint() + nsamples = inp.shape[0] # note this is the number of dataset samples, not + # multiplied by the sequence length + if module_class in (torch.nn.Linear, transformers.Conv1D): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() - nsamples = inp.shape[0] - breakpoint() - inp = inp.to(dtype=GPTQ_PRECISION) inp = math.sqrt(2 / nsamples) * inp return inp.matmul(inp.t()) From e6adc5a9a823cd5dcc615ad210a2c639bf09f7ce Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 18:56:01 +0000 Subject: [PATCH 17/59] layer compressor --- .../modifiers/quantization/gptq/base.py | 147 ++++++------------ .../quantization/gptq/utils/gptq_quantize.py | 32 ++-- .../quantization/gptq/utils/helpers.py | 9 +- .../modifiers/utils/layer_compressor.py | 124 +++++++++++++-- .../modifiers/utils/pytorch_helpers.py | 2 +- src/llmcompressor/utils/helpers.py | 13 +- 6 files changed, 187 insertions(+), 140 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 216d9e6cb..1b1e56f23 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,43 +1,35 @@ -import gc -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch -from torch.nn.utils.rnn import pad_sequence -from torch.utils.hooks import RemovableHandle -import contextlib -from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, freeze_module_quantization, ) +from compressed_tensors.utils import ( + is_module_offloaded, + update_parameter_data, + update_prefix_dict, +) from loguru import logger from pydantic import Field, field_validator +from torch.nn.utils.rnn import pad_sequence +from torch.utils.hooks import RemovableHandle from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.gptq.utils import ( - get_output_error, - gptq_hook +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( + quantize_weight, ) -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data -from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import calibration_forward_context, align_module, getattr_chain -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - qat_active, -) - -from compressed_tensors.utils import ( - get_offloaded_device, - is_module_offloaded, - update_parameter_data, - update_prefix_dict, +from llmcompressor.utils.helpers import ( + align_module, + calibration_forward_context, + getattr_chain, ) +from llmcompressor.utils.pytorch.module import qat_active __all__ = ["GPTQModifier"] @@ -138,13 +130,16 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return value - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._layer_index = 0 self._num_layers = 0 self.quantization_modifier_ = None + self.layer_compressor = LayerCompressor( + self.quantize_module, self.true_sequential + ) def on_initialize_structure(self, state: State, **kwargs): """ @@ -212,18 +207,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # find layers (used for printing even if true_sequential=True) - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - if self.sequential_targets is None: - self.sequential_targets = get_no_split_params(state.model) - layers = get_layers(self.sequential_targets, state.model) - self._num_layers = len(layers) - # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch - self.register_hooks(state.model, layers) + self.layer_compressor.register_hooks(state.model, self.sequential_targets) # apply calibration and trigger hooks (hooks are self removing) self.calibration_forward(state.model, state.data.calib) @@ -233,7 +219,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: state.model.apply(freeze_module_quantization) return True - + def on_finalize(self, state: "State", **kwargs) -> bool: """ disable the quantization observers used by the OBCQ algorithm @@ -243,81 +229,50 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) - self.remove_gptq_hooks() + self.layer_compressor.remove_hooks() return True - - def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): - for name, module in model.named_modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - pre_hook = partial(self.target_pre_forward, name) - post_hook = partial(self.target_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append(module.register_forward_hook(post_hook)) - - if module in layers.values(): - pre_hook = partial(self.layer_pre_forward, name) - post_hook = partial(self.layer_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append(module.register_forward_hook(post_hook, with_kwargs=True)) - - def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + + def calibration_forward( + self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader + ): dataset = dataloader.dataset + def collate_fn(batch): # extract input_ids and attention_mask from the batch input_ids = [torch.tensor(item["input_ids"]) for item in batch] attention_masks = [torch.tensor(item["attention_mask"]) for item in batch] - + # pad sequences in the batch - padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) - padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0) + padded_input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=0 + ) + padded_attention_masks = pad_sequence( + attention_masks, batch_first=True, padding_value=0 + ) return { "input_ids": padded_input_ids, - "attention_mask": padded_attention_masks + "attention_mask": padded_attention_masks, } - + dataloader = torch.utils.data.DataLoader( dataset, batch_size=len(dataset), shuffle=True, collate_fn=collate_fn, - pin_memory=True + pin_memory=True, ) - + with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) - @gptq_hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args): - if self.true_sequential: - # compress first so output is from quantized weights - self.quantize_module(name, module, args) - - @gptq_hook - def target_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any): - if not self.true_sequential: - # compress after so output is from unquantized weights - self.quantize_module(name, module, args) - - @gptq_hook - def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): - logger.info(f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====") - - @gptq_hook - def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], output: Any): - if not self.true_sequential: - # rerun with (now) quantized weights - with self.disable_hooks(): - output = module(*args, **kwargs) - - self._layer_index += 1 - return output - def quantize_module(self, name, module, args): logger.info(f"Compressing {name}...") - inp = args[0] # Assume that first argument is input (true for most Module types) + inp = args[ + 0 + ] # Assume that first argument is input (true for most Module types) quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight @@ -330,10 +285,10 @@ def quantize_module(self, name, module, args): percdamp=self.dampening_frac, module_class=type(module), ) - - #weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) + + # weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) weight = quantized_weight - + if is_module_offloaded(module): update_prefix_dict(self.layer, "weight", weight) update_parameter_data(module, scale, "weight_scale") @@ -341,18 +296,6 @@ def quantize_module(self, name, module, args): update_parameter_data(module, g_idx, "weight_g_idx") metrics_logger.set_losses(losses) - - @contextlib.contextmanager - def disable_hooks(self): - try: - self._hooks_disabled = True - yield - finally: - self._hooks_disabled = False - - def remove_gptq_hooks(self): - for hook in self._hooks: - hook.remove() def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 8e87f3ee0..a94a8bf69 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,16 +1,19 @@ +import math +from copy import copy from typing import Tuple, Union -import time -import math import torch import transformers -from copy import copy -from loguru import logger -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering, fake_quantize +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationStrategy, + fake_quantize, +) from compressed_tensors.quantization.observers import MovingAverageMinMaxObserver -from llmcompressor.pytorch.utils.helpers import tensor_sparsity -from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.pytorch.utils.helpers import tensor_sparsity GPTQ_PRECISION = torch.float32 @@ -21,7 +24,7 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: inp = inp.unsqueeze(0) nsamples = inp.shape[0] # note this is the number of dataset samples, not - # multiplied by the sequence length + # multiplied by the sequence length if module_class in (torch.nn.Linear, transformers.Conv1D): if len(inp.shape) == 3: @@ -43,7 +46,9 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H -def compute_scale_zeropoint(W: torch.Tensor, quant_args: QuantizationArgs) -> Tuple[torch.Tensor, torch.Tensor]: +def compute_scale_zeropoint( + W: torch.Tensor, quant_args: QuantizationArgs +) -> Tuple[torch.Tensor, torch.Tensor]: return MovingAverageMinMaxObserver(quant_args)(W) @@ -53,14 +58,16 @@ def quantize_weight( quant_args: QuantizationArgs, blocksize: int = 128, percdamp: float = 0.01, - module_class = torch.nn.Linear, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + module_class=torch.nn.Linear, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor +]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape final_dtype = weight.dtype W = weight.data.clone() - + H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype @@ -220,6 +227,7 @@ def quantize_weight( return losses, W, scale, zero_point, g_idx + def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index 6ebb1dc7a..fceb7fd75 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -1,10 +1,13 @@ +import time from typing import Any, Iterable, List, Tuple, Union -import time import torch from loguru import logger -from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes +from llmcompressor.utils.metric_logging import ( + get_GPU_memory_usage, + get_layer_size_bytes, +) __all__ = ["get_output_error", "gptq_hook", "MetricsLogger"] @@ -59,7 +62,7 @@ def gptq_hook(func): def wrapped(self, *args, **kwargs): if self._hooks_disabled: return - + func(self, *args, **kwargs) return wrapped diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 714d328df..b6fe73ce6 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,12 +1,14 @@ +import contextlib import operator -from typing import Dict, Tuple +from functools import partial +from typing import Any, Callable, Dict, List, Tuple, Union import torch from compressed_tensors import get_execution_device from loguru import logger -from torch.nn import Module from tqdm import tqdm +from llmcompressor.modifiers.quantization.gptq.utils.helpers import get_output_error from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device @@ -14,27 +16,123 @@ fix_fsdp_module_name, summon_full_params_context, ) -from llmcompressor.utils.pytorch import set_layer -from llmcompressor.utils.pytorch.module import get_prunable_layers +from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.pytorch.module import ( + get_layers, + get_no_split_params, + get_prunable_layers, + set_layer, +) __all__ = ["LayerCompressor"] -class LayerCompressorMixin: - def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): - return +class HooksMixin: + def __init__(self): + self.hooks = [] + self.hooks_disabled = False + + @classmethod + def hook(func): + def wrapped(self, *args, **kwargs): + if self.hooks_disabled: + return + + func(self, *args, **kwargs) + + return wrapped + + @contextlib.contextmanager + def disable_hooks(self): + try: + self._hooks_disabled = True + yield + finally: + self._hooks_disabled = False + + def remove_hooks(self): + for hook in self.hooks: + hook.remove() + + +class SequentialLayerCompressor(HooksMixin): + def __init__( + self, + compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], Any], + true_sequential: bool = True, + ): + self.compress_fn = compress_fn + self.true_sequential = true_sequential + + self._layer_index = 0 + self._num_layers = 0 + + def register_hooks( + self, model: torch.nn.Module, sequential_targets: Union[str, List[str], None] + ): + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + if self.sequential_targets is None: + self.sequential_targets = get_no_split_params(model) + layers = get_layers(sequential_targets, model) + self._num_layers = len(layers) + for name, module in model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append(module.register_forward_hook(post_hook)) if module in layers.values(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append( + module.register_forward_hook(post_hook, with_kwargs=True) + ) + + @HooksMixin.hook + def target_pre_forward(self, name: str, module: torch.nn.Module, args): + if self.true_sequential: + # compress first so output is from quantized weights + self.compress_fn(name, module, args) + + @HooksMixin.hook + def target_post_forward( + self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any + ): + if not self.true_sequential: + # compress after so output is from unquantized weights + self.compress_fn(name, module, args) + + @HooksMixin.hook + def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): + logger.info( + f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" + ) + + @HooksMixin.hook + def layer_post_forward( + self, + name: str, + module: torch.nn.Module, + args: torch.Tensor, + kwargs: Dict[str, Any], + output: Any, + ): + if not self.true_sequential: + # rerun with (now) compressed weights + with self.disable_hooks(): + compressed_output = module(*args, **kwargs) + + error = get_output_error(output, compressed_output) + logger.info(f"Mean output error from quantization: {error:.3f}") + + self._layer_index += 1 + return output class LayerCompressor: @@ -62,8 +160,8 @@ class LayerCompressor: def __init__( self, module_compressor_class: ModuleCompressionWrapper, - model: Module, - layer: Module, + model: torch.nn.Module, + layer: torch.nn.Module, layer_index: int, name: str, args: Dict, diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 20abaf376..c2f52a1cf 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -102,7 +102,7 @@ def run_calibration_forward( # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass - #torch.cuda.empty_cache() + # torch.cuda.empty_cache() return intermediates diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 14c724320..5891ab182 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -4,6 +4,7 @@ """ import ast +import contextlib import errno import fnmatch import glob @@ -15,23 +16,17 @@ import sys import tarfile import warnings -import contextlib from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import urlparse import numpy import torch +from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger -from compressed_tensors.quantization import ( - disable_quantization, - enable_quantization, -) -from compressed_tensors import is_module_offloaded - __all__ = [ "ALL_TOKEN", "ALL_PRUNABLE_TOKEN", @@ -1116,4 +1111,4 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) module._hf_hook.post_forward(module, torch.tensor([])) if device is not None: - module._hf_hook.execution_device = original_device \ No newline at end of file + module._hf_hook.execution_device = original_device From f65f8322633ec79de04f0cbad4e6ea763f21751e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 19:12:48 +0000 Subject: [PATCH 18/59] style --- .../modifiers/quantization/gptq/base.py | 5 ++-- src/llmcompressor/utils/helpers.py | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1b1e56f23..7e0d968dd 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -270,9 +270,8 @@ def collate_fn(batch): def quantize_module(self, name, module, args): logger.info(f"Compressing {name}...") - inp = args[ - 0 - ] # Assume that first argument is input (true for most Module types) + # Assume that first argument is input (true for most supported Module types) + inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 5891ab182..03abf18be 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -24,6 +24,7 @@ import numpy import torch +from compressed_tensors import is_module_offloaded from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger @@ -1102,13 +1103,25 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) :param device: optional device to move parameters to, if None is provided then module execution device will be used """ - if device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = device + if is_module_offloaded(module): + if device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = device - module._hf_hook.pre_forward(module) - yield - module._hf_hook.post_forward(module, torch.tensor([])) + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, torch.tensor([])) + + if device is not None: + module._hf_hook.execution_device = original_device + + elif device is not None: + devices = {} + for name, param in module.named_parameters(): + devices[name] = param.device + setattr(module, name, param.to(device)) + + yield - if device is not None: - module._hf_hook.execution_device = original_device + for name, param_device in module.named_parameters: + setattr(module, name, param.to(param_device)) From 1e225692d9ccb7a3769d7e1dc724770f69cb7d92 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 19:55:15 +0000 Subject: [PATCH 19/59] use layer compressor --- examples/quantization_w4a16/llama3_example.py | 12 +++-- .../modifiers/quantization/gptq/base.py | 38 ++++++-------- .../quantization/gptq/utils/helpers.py | 52 +++++-------------- .../modifiers/utils/layer_compressor.py | 48 +++++++++-------- src/llmcompressor/utils/helpers.py | 3 ++ .../pruning/sparsegpt/test_pytorch.py | 16 +++--- 6 files changed, 73 insertions(+), 96 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 2568c59ed..96adcbfdc 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,8 +6,8 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -#MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +# MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( @@ -44,7 +44,9 @@ def preprocess(example): # Tokenize inputs. -tokenizer.add_special_tokens({'pad_token': '[PAD]'}) +tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + def tokenize(sample): return tokenizer( sample["text"], @@ -59,7 +61,9 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01) +recipe = GPTQModifier( + targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01 +) # Apply algorithms. oneshot( diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7e0d968dd..44dcdf194 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -13,7 +13,6 @@ from loguru import logger from pydantic import Field, field_validator from torch.nn.utils.rnn import pad_sequence -from torch.utils.hooks import RemovableHandle from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory @@ -22,7 +21,7 @@ ) from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor +from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.helpers import ( align_module, @@ -72,6 +71,7 @@ class GPTQModifier(Modifier): :param sequential_update: Whether or not to update weights sequentially by layer, True saves on GPU memory, default is True + :param true_sequential: TODO :param targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass @@ -102,7 +102,7 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True - true_sequential: bool = False + true_sequential: bool = True targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -114,11 +114,8 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None - _layer_index: int = 0 - _num_layers: int = 0 - _hooks_disabled: bool = False - quantization_modifier_: Optional[QuantizationModifier] = None - _hooks: List[RemovableHandle] = [] + _quantization_modifier: Optional[QuantizationModifier] = None + _layer_compressor: Optional[SequentialLayerCompressor] = None @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -134,10 +131,7 @@ def validate_sequential_update(cls, value: bool) -> bool: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._layer_index = 0 - self._num_layers = 0 - self.quantization_modifier_ = None - self.layer_compressor = LayerCompressor( + self._layer_compressor = SequentialLayerCompressor( self.quantize_module, self.true_sequential ) @@ -191,8 +185,8 @@ def on_initialize_structure(self, state: State, **kwargs): self._build_quant_modifier_from_dict(self.quantize) self.quantize = True - if self.quantization_modifier_: - self.quantization_modifier_.on_initialize_structure(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.on_initialize_structure(state, **kwargs) def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -202,14 +196,14 @@ def on_initialize(self, state: "State", **kwargs) -> bool: """ if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) - if self.quantization_modifier_: - self.quantization_modifier_.initialize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch - self.layer_compressor.register_hooks(state.model, self.sequential_targets) + self._layer_compressor.register_hooks(state.model, self.sequential_targets) # apply calibration and trigger hooks (hooks are self removing) self.calibration_forward(state.model, state.data.calib) @@ -226,10 +220,10 @@ def on_finalize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ - if self.quantization_modifier_: - self.quantization_modifier_.finalize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.finalize(state, **kwargs) - self.layer_compressor.remove_hooks() + self._layer_compressor.remove_hooks() return True @@ -301,7 +295,7 @@ def _build_quant_modifier(self): Build a quantization modifier based on the specified config_groups, ignore list, and num_calibration_steps. - :postcondition: self.quantization_modifier_ is set to the built + :postcondition: self._quantization_modifier is set to the built quantization modifier """ @@ -327,7 +321,7 @@ def _build_quant_modifier(self): def _build_quant_modifier_from_dict(self, quant_config): modifier_type = list(quant_config.keys())[0] modifier_args = quant_config[modifier_type] - self.quantization_modifier_ = ModifierFactory.create( + self._quantization_modifier = ModifierFactory.create( modifier_type, allow_registered=True, allow_experimental=True, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index fceb7fd75..a369e0d4c 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -13,51 +13,23 @@ def get_output_error( - unquantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], - quantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], + uncompressed: Tuple[torch.Tensor, ...], + compressed: Tuple[torch.Tensor, ...], ) -> torch.Tensor: """ - Calculate mean l1 loss between weight-unquantized outputs and weight-quantized - outputs + Calculate mean absolute error between weight-uncompressed outputs and + weight-compressed outputs - :param unquantized: unquantized-weight outputs - :param quantized: quantized-weight outputs - :return: mean l1 loss between outputs + :param uncompressed: uncompressed-weight outputs + :param compressed: compressed-weight outputs + :return: mean absolute error between outputs """ - unquantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in unquantized - ], - start=[], - ) - - quantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in quantized - ], - start=[], - ) - - if len(unquantized_outputs) != len(quantized_outputs): - raise ValueError( - "Number of samples of weight-unquantized and weight-quantized " - "outputs differs" - ) - - return sum( - [ - torch.nn.functional.l1_loss(unq, q) - for unq, q in zip(unquantized_outputs, quantized_outputs) - ] - ) / len(unquantized_outputs) - + # assume first output is the the relevant output (true for most Modules) + uncompressed = uncompressed[0] + compressed = compressed[0] + return torch.mean(torch.abs(uncompressed - compressed)) + def gptq_hook(func): def wrapped(self, *args, **kwargs): if self._hooks_disabled: diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index b6fe73ce6..a2bdf0582 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -28,30 +28,35 @@ class HooksMixin: - def __init__(self): - self.hooks = [] - self.hooks_disabled = False + HOOKS_DISABLED: bool = False @classmethod - def hook(func): - def wrapped(self, *args, **kwargs): - if self.hooks_disabled: + def hook(cls, func): + def wrapped(*args, **kwargs): + if cls.HOOKS_DISABLED: return - func(self, *args, **kwargs) + func(*args, **kwargs) return wrapped + @classmethod @contextlib.contextmanager - def disable_hooks(self): + def disable_hooks(cls): try: - self._hooks_disabled = True + cls.HOOKS_DISABLED = True yield finally: - self._hooks_disabled = False + cls.HOOKS_DISABLED = False + + def __init__(self): + self._hooks = [] + + def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + self._hooks.append(handle) def remove_hooks(self): - for hook in self.hooks: + for hook in self._hooks: hook.remove() @@ -61,6 +66,7 @@ def __init__( compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], Any], true_sequential: bool = True, ): + HooksMixin.__init__(self) self.compress_fn = compress_fn self.true_sequential = true_sequential @@ -74,8 +80,8 @@ def register_hooks( # if no targets are provided, default to the modules that shouldn't be # split by FSDP. For Transformers models this is equivalent to the # decoder layers (ie LlamaDecoderLayer) - if self.sequential_targets is None: - self.sequential_targets = get_no_split_params(model) + if sequential_targets is None: + sequential_targets = get_no_split_params(model) layers = get_layers(sequential_targets, model) self._num_layers = len(layers) @@ -83,16 +89,14 @@ def register_hooks( if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append(module.register_forward_hook(post_hook)) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_hook(post_hook)) - if module in layers.values(): + if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append( - module.register_forward_hook(post_hook, with_kwargs=True) - ) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_hook(post_hook, with_kwargs=True)) @HooksMixin.hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -121,11 +125,11 @@ def layer_post_forward( module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], - output: Any, + output: Tuple[torch.Tensor, ...], ): if not self.true_sequential: # rerun with (now) compressed weights - with self.disable_hooks(): + with HooksMixin.disable_hooks(): compressed_output = module(*args, **kwargs) error = get_output_error(output, compressed_output) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 03abf18be..c414d134b 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1125,3 +1125,6 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) for name, param_device in module.named_parameters: setattr(module, name, param.to(param_device)) + + else: + yield diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 5421af4cf..3cdc25038 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -75,15 +75,15 @@ def test_create_default_quant_modifier(self): kwargs = dict(block_size=128) modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - modifier.quantization_modifier_.create_init_config() + assert isinstance(modifier._quantization_modifier, QuantizationModifier) + modifier._quantization_modifier.create_init_config() default_config_group_name = "group_0" - should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ + should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[ default_config_group_name ] assert should_be_default_quant_scheme.input_activations is None @@ -113,7 +113,7 @@ def test_set_quant_if_modifer_already_exists(self): kwargs = dict(block_size=128) modifier = GPTQModifier(**kwargs) - assert not modifier.quantization_modifier_ + assert not modifier._quantization_modifier modifier.on_initialize_structure(testing_harness.get_state()) # since quantization modifier is already applied, quantization must be set in @@ -150,14 +150,14 @@ def test_set_quant_in_gptq(self): kwargs = dict(block_size=128, quantize=self.quant_config) modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) + self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) - dict_scheme = dict(modifier.quantization_modifier_.config_groups) + dict_scheme = dict(modifier._quantization_modifier.config_groups) self._check_config( dict(dict_scheme["config_group_0"].weights), self.quant_kwargs["config_groups"]["config_group_0"]["weights"], From 9324695131e30f3755a6b2d64be6c6a3d0731ac5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 21:16:40 +0000 Subject: [PATCH 20/59] replicate dtypes --- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a94a8bf69..a193ae817 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -49,7 +49,12 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: def compute_scale_zeropoint( W: torch.Tensor, quant_args: QuantizationArgs ) -> Tuple[torch.Tensor, torch.Tensor]: - return MovingAverageMinMaxObserver(quant_args)(W) + # TODO: revisit after observers refactor + + scale, zero_point = quant_args.get_observer()(W, g_idx=None) + scale = scale.to(dtype=W.dtype) + zero_point = zero_point.to(dtype=quant_args.pytorch_dtype()) + return scale, zero_point def quantize_weight( From eef4fb6f666b688f1f3b936a6efe8c170a644908 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 21:21:30 +0000 Subject: [PATCH 21/59] write weight changes --- src/llmcompressor/modifiers/quantization/gptq/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 44dcdf194..7e0617bdb 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -102,7 +102,7 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True - true_sequential: bool = True + true_sequential: bool = False targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -284,6 +284,7 @@ def quantize_module(self, name, module, args): if is_module_offloaded(module): update_prefix_dict(self.layer, "weight", weight) + update_parameter_data(module, weight, "weight") update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") From 485813a6d73a4b0749d2fc44f839243423601269 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:04:39 +0000 Subject: [PATCH 22/59] revert example --- examples/quantization_w4a16/llama3_example.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 96adcbfdc..939991ab6 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,4 +1,3 @@ -import torch from datasets import load_dataset from transformers import AutoTokenizer @@ -6,9 +5,7 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -# MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, @@ -23,8 +20,8 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 // 6 -MAX_SEQUENCE_LENGTH = 2048 // 2 +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) @@ -44,13 +41,10 @@ def preprocess(example): # Tokenize inputs. -tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - def tokenize(sample): return tokenizer( sample["text"], - padding=True, + padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False, @@ -61,9 +55,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01 -) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) # Apply algorithms. oneshot( From 60061551514efc3f15f1fbd6e38e033703252bb5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:33:16 +0000 Subject: [PATCH 23/59] organization --- .../modifiers/quantization/gptq/base.py | 25 +++--- .../quantization/gptq/utils/gptq_quantize.py | 8 +- .../quantization/gptq/utils/helpers.py | 84 ------------------- .../modifiers/utils/layer_compressor.py | 24 ++++-- src/llmcompressor/utils/metric_logging.py | 54 +++++++++++- 5 files changed, 82 insertions(+), 113 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7e0617bdb..2917c85c4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from compressed_tensors.quantization import ( @@ -19,7 +19,6 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( quantize_weight, ) -from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -261,16 +260,17 @@ def collate_fn(batch): with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) - def quantize_module(self, name, module, args): - logger.info(f"Compressing {name}...") + def quantize_module( + self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...] + ) -> float: + logger.info(f"Quantizing {name}...") - # Assume that first argument is input (true for most supported Module types) + # Assume that first argument is the input inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - # with onloaded weight - with align_module(module), MetricsLogger(module) as metrics_logger: - losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( + with align_module(module): + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, quant_args, @@ -279,17 +279,16 @@ def quantize_module(self, name, module, args): module_class=type(module), ) - # weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) - weight = quantized_weight + # FUTURE: Implement learning rate modification to weight update if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", weight) - update_parameter_data(module, weight, "weight") + update_prefix_dict(self.layer, "weight", quantized_weight) + update_parameter_data(module, quantized_weight, "weight") update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") - metrics_logger.set_losses(losses) + return loss def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a193ae817..b21956cee 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -10,7 +10,6 @@ QuantizationStrategy, fake_quantize, ) -from compressed_tensors.quantization.observers import MovingAverageMinMaxObserver from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD from llmcompressor.pytorch.utils.helpers import tensor_sparsity @@ -64,9 +63,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class=torch.nn.Linear, -) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor -]: +) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -230,7 +227,8 @@ def quantize_weight( W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - return losses, W, scale, zero_point, g_idx + loss = torch.sum(losses).item() + return loss, W, scale, zero_point, g_idx def _apply_activation_ordering( diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py deleted file mode 100644 index a369e0d4c..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ /dev/null @@ -1,84 +0,0 @@ -import time -from typing import Any, Iterable, List, Tuple, Union - -import torch -from loguru import logger - -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) - -__all__ = ["get_output_error", "gptq_hook", "MetricsLogger"] - - -def get_output_error( - uncompressed: Tuple[torch.Tensor, ...], - compressed: Tuple[torch.Tensor, ...], -) -> torch.Tensor: - """ - Calculate mean absolute error between weight-uncompressed outputs and - weight-compressed outputs - - :param uncompressed: uncompressed-weight outputs - :param compressed: compressed-weight outputs - :return: mean absolute error between outputs - """ - # assume first output is the the relevant output (true for most Modules) - uncompressed = uncompressed[0] - compressed = compressed[0] - - return torch.mean(torch.abs(uncompressed - compressed)) - -def gptq_hook(func): - def wrapped(self, *args, **kwargs): - if self._hooks_disabled: - return - - func(self, *args, **kwargs) - - return wrapped - - -class MetricsLogger: - def __init__(self, module: torch.nn.Module): - self.module = module - self.start_tick = None - self.losses = None - - def set_losses(self, losses: torch.Tensor): - self.losses = losses - - def __enter__(self) -> "MetricsLogger": - self.start_tick = time.time() - return self - - def __exit__(self, _exc_type, _exc_val, _exc_tb): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - - if self.start_tick is not None: - patch.log("METRIC", "time %.2f" % (time.time() - self.start_tick)) - if self.losses is not None: - patch.log("METRIC", "error %.2f" % torch.sum(self.losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - compressed_size = get_layer_size_bytes(self.module) - patch.log("METRIC", f"Compressed layer size: {compressed_size} MB") diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index a2bdf0582..5bb459372 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -8,7 +8,6 @@ from loguru import logger from tqdm import tqdm -from llmcompressor.modifiers.quantization.gptq.utils.helpers import get_output_error from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device @@ -17,6 +16,7 @@ summon_full_params_context, ) from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -24,7 +24,7 @@ set_layer, ) -__all__ = ["LayerCompressor"] +__all__ = ["SequentialLayerCompressor", "LayerCompressor"] class HooksMixin: @@ -63,7 +63,7 @@ def remove_hooks(self): class SequentialLayerCompressor(HooksMixin): def __init__( self, - compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], Any], + compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], float], true_sequential: bool = True, ): HooksMixin.__init__(self) @@ -96,21 +96,27 @@ def register_hooks( pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) self.register_hook(module.register_forward_pre_hook(pre_hook)) - self.register_hook(module.register_forward_hook(post_hook, with_kwargs=True)) + self.register_hook( + module.register_forward_hook(post_hook, with_kwargs=True) + ) @HooksMixin.hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): if self.true_sequential: - # compress first so output is from quantized weights - self.compress_fn(name, module, args) + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_fn(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def target_post_forward( self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any ): if not self.true_sequential: - # compress after so output is from unquantized weights - self.compress_fn(name, module, args) + # compress after so output is from uncompressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_fn(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): @@ -132,7 +138,7 @@ def layer_post_forward( with HooksMixin.disable_hooks(): compressed_output = module(*args, **kwargs) - error = get_output_error(output, compressed_output) + error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) logger.info(f"Mean output error from quantization: {error:.3f}") self._layer_index += 1 diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index d0b3bb11e..b23ba200a 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -1,7 +1,10 @@ +import time from typing import List, Tuple +import torch from loguru import logger -from torch.nn import Module + +__all__ = ["CompressionLogger"] def get_GPU_memory_usage() -> List[Tuple]: @@ -35,7 +38,7 @@ def get_GPU_memory_usage() -> List[Tuple]: return [] -def get_layer_size_bytes(module: Module) -> float: +def get_module_size_bytes(module: torch.nn.Module) -> float: param_size = 0 buffer_size = 0 @@ -49,3 +52,50 @@ def get_layer_size_bytes(module: Module) -> float: total_size_mb = total_size / (1024**2) # Convert bytes to MB return total_size_mb + + +class CompressionLogger: + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + + def __init__(self, module: torch.nn.Module): + self.module = module + self.start_tick = None + self.loss = None + + def set_loss(self, loss: float): + self.loss = loss + + def __enter__(self) -> "CompressionLogger": + self.start_tick = time.time() + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + stop_tick = time.time() + patch = logger.patch(lambda r: r.update(function="compress")) + + if self.start_tick is not None: + duration = stop_tick - self.start_tick + patch.log("METRIC", f"time {duration:.2f}") + if self.loss is not None: + patch.log("METRIC", f"error {self.loss:.2f}") + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + compressed_size = get_module_size_bytes(self.module) + patch.log("METRIC", f"Compressed module size: {compressed_size} MB") From c10d2ee3d9f82d68926fe1592f67a9dd2b798bbc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:50:40 +0000 Subject: [PATCH 24/59] add create_single_batch_dataloader --- .../modifiers/quantization/gptq/base.py | 33 +++-------------- .../quantization/gptq/utils/__init__.py | 1 - .../finetune/data/data_helpers.py | 35 +++++++++++++++++-- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 2917c85c4..1b03130c7 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -12,7 +12,6 @@ ) from loguru import logger from pydantic import Field, field_validator -from torch.nn.utils.rnn import pad_sequence from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory @@ -22,6 +21,9 @@ from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.transformers.finetune.data.data_helpers import ( + create_single_batch_dataloader, +) from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -229,34 +231,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader ): - dataset = dataloader.dataset - - def collate_fn(batch): - # extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item["input_ids"]) for item in batch] - attention_masks = [torch.tensor(item["attention_mask"]) for item in batch] - - # pad sequences in the batch - padded_input_ids = pad_sequence( - input_ids, batch_first=True, padding_value=0 - ) - padded_attention_masks = pad_sequence( - attention_masks, batch_first=True, padding_value=0 - ) - - return { - "input_ids": padded_input_ids, - "attention_mask": padded_attention_masks, - } - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=len(dataset), - shuffle=True, - collate_fn=collate_fn, - pin_memory=True, - ) - + dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py index 5703ced46..ec39da973 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa from .gptq_quantize import * -from .helpers import * diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 23c70e561..933f64bd9 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,9 +1,11 @@ import logging import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Optional +import datasets import torch from datasets import Dataset, load_dataset +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator @@ -11,20 +13,49 @@ LABELS_MASK_VALUE = -100 __all__ = [ + "create_single_batch_dataloader", "format_calibration_data", "get_raw_dataset", "make_dataset_splits", "get_custom_datasets_from_path", + "LABELS_MASK_VALUE", ] +def create_single_batch_dataloader( + dataset: datasets.Dataset, +) -> torch.utils.data.DataLoader: + def pad_sequences(batch): + # extract input_ids and attention_mask from the batch + input_ids = [torch.tensor(item["input_ids"]) for item in batch] + masks = [torch.tensor(item["attention_mask"]) for item in batch] + + # while 0 is not necessarily the "correct" padding value, the padded + # input_ids are ignored according to the attention_mask + pad_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + pad_masks = pad_sequence(masks, batch_first=True, padding_value=0) + + return { + "input_ids": pad_input_ids, + "attention_mask": pad_masks, + } + + return torch.utils.data.DataLoader( + dataset, + batch_size=len(dataset), + shuffle=True, + collate_fn=pad_sequences, + pin_memory=True, + ) + + def format_calibration_data( tokenized_dataset: Dataset, num_calibration_samples: Optional[int] = None, do_shuffle: bool = True, collate_fn: Callable = default_data_collator, accelerator: Optional[Any] = None, -) -> List[torch.Tensor]: +) -> torch.utils.data.DataLoader: """ Creates a dataloader out of the calibration dataset split, trimming it to the desired number of calibration samples From 637119322b4cf2202b34a7f7385d5102fe5ff013 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:54:13 +0000 Subject: [PATCH 25/59] add back empty_cache until I can justify removing it --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1b03130c7..b59085dab 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -206,7 +206,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # after lifecycle refactor, move this to pre_batch self._layer_compressor.register_hooks(state.model, self.sequential_targets) - # apply calibration and trigger hooks (hooks are self removing) + # apply calibration and trigger hooks self.calibration_forward(state.model, state.data.calib) # freeze quantization diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index c2f52a1cf..9003ff22d 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -102,7 +102,7 @@ def run_calibration_forward( # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass - # torch.cuda.empty_cache() + torch.cuda.empty_cache() return intermediates From 92315a5197e9321e6d9bd3667b9505e9a527027c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:26:56 +0000 Subject: [PATCH 26/59] better type hinting, faster mask applying --- src/llmcompressor/modifiers/utils/layer_compressor.py | 11 ++++++----- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 5bb459372..833c8176d 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -101,7 +101,8 @@ def register_hooks( ) @HooksMixin.hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args): + def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any, ...]): + breakpoint() if self.true_sequential: # compress first so output is from compressed weights with CompressionLogger(module) as comp_logger: @@ -110,7 +111,7 @@ def target_pre_forward(self, name: str, module: torch.nn.Module, args): @HooksMixin.hook def target_post_forward( - self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], _output: Tuple[Any, ...] ): if not self.true_sequential: # compress after so output is from uncompressed weights @@ -119,7 +120,7 @@ def target_post_forward( comp_logger.set_loss(loss) @HooksMixin.hook - def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): + def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) @@ -129,9 +130,9 @@ def layer_post_forward( self, name: str, module: torch.nn.Module, - args: torch.Tensor, + args: Tuple[Any, ...], kwargs: Dict[str, Any], - output: Tuple[torch.Tensor, ...], + output: Tuple[Any, ...], ): if not self.true_sequential: # rerun with (now) compressed weights diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9003ff22d..43b261d99 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -39,7 +39,7 @@ def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.T :param batch: batch to apply padding to if it exists :return: batch with padding zeroed out in the input_ids """ - batch["input_ids"] = batch["input_ids"] * batch["attention_mask"] + batch["input_ids"].masked_fill_(batch["attention_mask"] == 0, 0) return batch From 8a25c68438a487b423c8b872b3b42d7a30b36b4d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:37:59 +0000 Subject: [PATCH 27/59] remove breakpoint --- src/llmcompressor/modifiers/utils/layer_compressor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 833c8176d..8fd933a1d 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -102,7 +102,6 @@ def register_hooks( @HooksMixin.hook def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any, ...]): - breakpoint() if self.true_sequential: # compress first so output is from compressed weights with CompressionLogger(module) as comp_logger: From 6cd0d6cc1255fea96fd99599cc09b5e72c44bdb0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:41:10 +0000 Subject: [PATCH 28/59] apply style, add true_sequential docstring --- src/llmcompressor/modifiers/quantization/gptq/base.py | 7 +++++-- src/llmcompressor/modifiers/utils/layer_compressor.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 04260340c..0958602da 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -11,7 +12,6 @@ update_prefix_dict, ) from loguru import logger -import warnings from pydantic import Field, field_validator from llmcompressor.core import State @@ -72,7 +72,10 @@ class GPTQModifier(Modifier): :param sequential_update: Whether or not to update weights sequentially by layer. This option is depreciated and setting to False is no longer supported - :param true_sequential: TODO + :param true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block) :param targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 8fd933a1d..c2130068f 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -101,7 +101,9 @@ def register_hooks( ) @HooksMixin.hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any, ...]): + def target_pre_forward( + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] + ): if self.true_sequential: # compress first so output is from compressed weights with CompressionLogger(module) as comp_logger: @@ -110,7 +112,11 @@ def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any @HooksMixin.hook def target_post_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], _output: Tuple[Any, ...] + self, + name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + _output: Tuple[Any, ...], ): if not self.true_sequential: # compress after so output is from uncompressed weights From 0e0c586c4773d5b9b2176d4d0688468361bbc94a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:42:25 +0000 Subject: [PATCH 29/59] update docstring --- src/llmcompressor/modifiers/quantization/gptq/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 0958602da..22404f0da 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -12,6 +11,7 @@ update_prefix_dict, ) from loguru import logger +import warnings from pydantic import Field, field_validator from llmcompressor.core import State @@ -53,6 +53,7 @@ class GPTQModifier(Modifier): | test_stage: | obcq_modifiers: | GPTQModifier: + | true_sequential: False | dampening_frac: 0.001 | block_size: 128 | config_groups: @@ -75,7 +76,7 @@ class GPTQModifier(Modifier): :param true_sequential: Used to control the granularity of compression updates through the forward pass. Set to True to use the weight-compressed outputs of each module, set to False to use the weight-compressed outputs of each - layer (transformer block) + layer (transformer block), defaults to False :param targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass From d23aabb1330aa408b3374eb79341e08a9b0e3f7b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:46:58 +0000 Subject: [PATCH 30/59] use private attrs --- src/llmcompressor/modifiers/quantization/gptq/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 22404f0da..05a7ab0a6 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -12,7 +12,7 @@ ) from loguru import logger import warnings -from pydantic import Field, field_validator +from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory @@ -119,8 +119,8 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None - _quantization_modifier: Optional[QuantizationModifier] = None - _layer_compressor: Optional[SequentialLayerCompressor] = None + _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() + _layer_compressor: Optional[SequentialLayerCompressor] = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: From 355074b2bf815ac9a06f30617c50b15d9ccd4364 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:08:23 +0000 Subject: [PATCH 31/59] more docstring --- .../modifiers/quantization/gptq/base.py | 2 +- .../modifiers/utils/layer_compressor.py | 27 ++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 05a7ab0a6..0f4f7de95 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -11,7 +12,6 @@ update_prefix_dict, ) from loguru import logger -import warnings from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index c2130068f..e3e5c6217 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,7 +1,7 @@ import contextlib import operator from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from compressed_tensors import get_execution_device @@ -61,10 +61,29 @@ def remove_hooks(self): class SequentialLayerCompressor(HooksMixin): + """ + Apply a given compression function to a model during the model's calibration + forward pass + + Lifecycle: + - QuantizationModifier.initialize(model) + - SequentialLayerCompressor(compress_fn) + - register_hooks(model) + - model.forward() + - compress_fn(name, target_module, args) + - remove_hooks() + + :param compress_fn: Function to be called on target modules + :param true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block), defaults to False + """ + def __init__( self, compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], float], - true_sequential: bool = True, + true_sequential: bool = False, ): HooksMixin.__init__(self) self.compress_fn = compress_fn @@ -74,7 +93,9 @@ def __init__( self._num_layers = 0 def register_hooks( - self, model: torch.nn.Module, sequential_targets: Union[str, List[str], None] + self, + model: torch.nn.Module, + sequential_targets: Optional[Union[str, List[str]]] = None, ): # find layers (used for printing even if true_sequential=True) # if no targets are provided, default to the modules that shouldn't be From bf2184d60bcc8671b3d6d6e587ba0dcfe75504da Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:18:43 +0000 Subject: [PATCH 32/59] docstrings --- .../modifiers/quantization/gptq/base.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 0f4f7de95..2c31d0a70 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -235,6 +235,13 @@ def on_finalize(self, state: "State", **kwargs) -> bool: def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader ): + """ + Perform calibration forward pass with one batch whose size is the size + of the dataset + + :param model: model to perform forward pass with + :param dataloader: dataloader containing calibration dataset + """ dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) @@ -242,6 +249,15 @@ def calibration_forward( def quantize_module( self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...] ) -> float: + """ + Quantize a module's weight according to the GPTQ algorithm + + :param name: name of module being quantized + :param module: module being quantized + :param args: input arguments for module forward pass + + :return: total loss from applying weight quantization to this module + """ logger.info(f"Quantizing {name}...") # Assume that first argument is the input From 0b418c7efe2a6810179d942030dd40a0e7c38148 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:27:14 +0000 Subject: [PATCH 33/59] docstrings --- .../quantization/gptq/utils/gptq_quantize.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index b21956cee..4301d944a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -18,6 +18,13 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: + """ + Calculate the hessian with respect to the module inputs + + :param inp: module inputs + :param module_class: class of module, likely torch.nn.Linear + :return: hessian w.r.t. module inputs + """ inp = inp.to(device=device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) @@ -36,6 +43,13 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: + """ + Performs in-place inversion of the hessian in order to save memory + + :param H: hessian being inverted + :param percdamp: dampening factor on hessian diagonal + :return: inverted hessian + """ damp = percdamp * torch.mean(torch.diag(H)) diag = torch.arange(H.shape[0], device=H.device) H[diag, diag] += damp @@ -45,9 +59,18 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H -def compute_scale_zeropoint( +def compute_scale_zero_point( W: torch.Tensor, quant_args: QuantizationArgs ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the scale and zero point of a module weight + TODO: revisit after observers refactor + + :param W: module weight + :param quant_args: quantization arguments which determine how quantization + parameters are calculated + :return: scale and zero_point + """ # TODO: revisit after observers refactor scale, zero_point = quant_args.get_observer()(W, g_idx=None) @@ -64,6 +87,17 @@ def quantize_weight( percdamp: float = 0.01, module_class=torch.nn.Linear, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + """ + Quantize a module weight according to the GPTQ algorithm + + :param weight: weight being quantized + :param inp: module inputs used to calculate hessian + :param quant_args: quantization arguments used to find quantization parameters + :param blocksize: chunk size of quantization updates + :param percdamp: dampening factor on hessian diagonal + :param module_class: class of module, likely torch.nn.Linear + :return: loss, quantized_weight, scale, zero_point, g_idx + """ strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -91,22 +125,22 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] else: - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) else: - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) # sparsity mask sparsity = tensor_sparsity(W) @@ -238,6 +272,8 @@ def _apply_activation_ordering( Permute weight and hessian in order of greatest outupt activations :param W: weight to permute + :param H: hessian used to determine activation ordering + :return: permuted weight, permuted hessian, permutation map """ perm = torch.argsort(torch.diag(H), descending=True) return W[:, perm], H[perm][:, perm], perm From 56cceeaccb6cd0ac54a39b53c678751ad807ccd5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:31:58 +0000 Subject: [PATCH 34/59] docstrings --- .../transformers/finetune/data/data_helpers.py | 6 ++++++ src/llmcompressor/utils/helpers.py | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 933f64bd9..8a6f097a3 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -25,6 +25,12 @@ def create_single_batch_dataloader( dataset: datasets.Dataset, ) -> torch.utils.data.DataLoader: + """ + Create a dataloader whose batch size is equal to the size of the dataset + + :param dataset: dataset used to generate dataloader + :return: dataloader + """ def pad_sequences(batch): # extract input_ids and attention_mask from the batch input_ids = [torch.tensor(item["input_ids"]) for item in batch] diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index c414d134b..211bd01eb 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1077,6 +1077,9 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): @contextlib.contextmanager def DisableQuantization(model: torch.nn.Module): + """ + Disable quantization from QuantizationModifier + """ model.apply(disable_quantization) yield model.apply(enable_quantization) @@ -1084,6 +1087,13 @@ def DisableQuantization(model: torch.nn.Module): @contextlib.contextmanager def calibration_forward_context(model: torch.nn.Module): + """ + Context in which all calibration forward passes should occur. + + - Remove gradient calculations + - Disable the KV cache + - Disable quantization from QuantizationModifier + """ model.eval() with ( From 7c7e3bc964921384472bb7f24d48e0759cc3e610 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:33:44 +0000 Subject: [PATCH 35/59] move hooksmixin to separate file --- .../quantization/gptq/utils/gptq_quantize.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 38 +++++++++++++++++++ .../modifiers/utils/layer_compressor.py | 35 +---------------- .../finetune/data/data_helpers.py | 3 +- 4 files changed, 42 insertions(+), 36 deletions(-) create mode 100644 src/llmcompressor/modifiers/utils/hooks.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 4301d944a..022252b0a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -20,7 +20,7 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs - + :param inp: module inputs :param module_class: class of module, likely torch.nn.Linear :return: hessian w.r.t. module inputs diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py new file mode 100644 index 000000000..d7e35015f --- /dev/null +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -0,0 +1,38 @@ +import contextlib + +import torch + +__all__ = ["HooksMixin"] + + +class HooksMixin: + HOOKS_DISABLED: bool = False + + @classmethod + def hook(cls, func): + def wrapped(*args, **kwargs): + if cls.HOOKS_DISABLED: + return + + func(*args, **kwargs) + + return wrapped + + @classmethod + @contextlib.contextmanager + def disable_hooks(cls): + try: + cls.HOOKS_DISABLED = True + yield + finally: + cls.HOOKS_DISABLED = False + + def __init__(self): + self._hooks = [] + + def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + self._hooks.append(handle) + + def remove_hooks(self): + for hook in self._hooks: + hook.remove() diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index e3e5c6217..b168e2534 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,4 +1,3 @@ -import contextlib import operator from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -9,6 +8,7 @@ from tqdm import tqdm from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper +from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.utils.fsdp.context import ( @@ -27,39 +27,6 @@ __all__ = ["SequentialLayerCompressor", "LayerCompressor"] -class HooksMixin: - HOOKS_DISABLED: bool = False - - @classmethod - def hook(cls, func): - def wrapped(*args, **kwargs): - if cls.HOOKS_DISABLED: - return - - func(*args, **kwargs) - - return wrapped - - @classmethod - @contextlib.contextmanager - def disable_hooks(cls): - try: - cls.HOOKS_DISABLED = True - yield - finally: - cls.HOOKS_DISABLED = False - - def __init__(self): - self._hooks = [] - - def register_hook(self, handle: torch.utils.hooks.RemovableHandle): - self._hooks.append(handle) - - def remove_hooks(self): - for hook in self._hooks: - hook.remove() - - class SequentialLayerCompressor(HooksMixin): """ Apply a given compression function to a model during the model's calibration diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 8a6f097a3..cc1c946ac 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -27,10 +27,11 @@ def create_single_batch_dataloader( ) -> torch.utils.data.DataLoader: """ Create a dataloader whose batch size is equal to the size of the dataset - + :param dataset: dataset used to generate dataloader :return: dataloader """ + def pad_sequences(batch): # extract input_ids and attention_mask from the batch input_ids = [torch.tensor(item["input_ids"]) for item in batch] From 2d52183760cebb80a9a87ce0c3e3a07796c799a7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:39:46 +0000 Subject: [PATCH 36/59] docstrings --- src/llmcompressor/modifiers/utils/hooks.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index d7e35015f..19c9a34ce 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -6,6 +6,15 @@ class HooksMixin: + """" + Class to manage the registration, disabling, and removal of hooks. Registering + and removing hooks should be handled by modifier classes which inherit from this + mixin, while disabling hooks should disable all hooks across modifiers. + + Modifiers which implement hooks should use the @HooksMixin.hook decorator + Modifiers must pass registered hooks handles to self.register_hook() and must + remove hooks when finished using self.remove_hooks() + """ HOOKS_DISABLED: bool = False @classmethod @@ -21,6 +30,10 @@ def wrapped(*args, **kwargs): @classmethod @contextlib.contextmanager def disable_hooks(cls): + """ + Disable all hooks across all modifiers + TODO: select which modifier hooks are disabled/ kept enabled + """ try: cls.HOOKS_DISABLED = True yield @@ -31,8 +44,16 @@ def __init__(self): self._hooks = [] def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + """ + Usage: self.register_hook(module.register_forward_hook(...)) + + :param handle: handle of added hook + """ self._hooks.append(handle) def remove_hooks(self): + """ + Remove all hooks belonging to a modifier + """ for hook in self._hooks: hook.remove() From 9081f12f0239860344fe0fd5d86988db3cd88ecb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 01:18:33 -0400 Subject: [PATCH 37/59] fix docstring, better arguments grouping --- .../modifiers/quantization/gptq/base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 2c31d0a70..15f16727e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -77,8 +77,8 @@ class GPTQModifier(Modifier): through the forward pass. Set to True to use the weight-compressed outputs of each module, set to False to use the weight-compressed outputs of each layer (transformer block), defaults to False - :param targets: list of layer names to compress during GPTQ, or '__ALL__' - to compress every layer in the model + :param sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not @@ -108,16 +108,18 @@ class GPTQModifier(Modifier): sequential_update: bool = True # DEPRECIATED true_sequential: bool = False - targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 - quantize: Union[bool, Dict] = True dampening_frac: Optional[float] = 0.01 + quantize: Union[bool, Dict] = True + + # arguments used for quant modifier config_groups: Optional[Dict[str, QuantizationScheme]] = None + scheme: Optional[Union[str, Dict[str, Any]]] = None + targets: Union[str, List[str], None] = None ignore: List[str] = Field(default_factory=list) - disable_quantization_observer_epoch: Optional[float] = None num_calibration_steps: Optional[int] = None - scheme: Optional[Union[str, Dict[str, Any]]] = None + disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() _layer_compressor: Optional[SequentialLayerCompressor] = PrivateAttr() From 96e9496f266059494e0e88fd190606bfc5835ccd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 24 Oct 2024 03:53:00 +0000 Subject: [PATCH 38/59] use LayerCompressorMixin --- .../modifiers/quantization/gptq/base.py | 23 ++- src/llmcompressor/modifiers/utils/hooks.py | 148 ++++++++++++++++-- .../modifiers/utils/layer_compressor.py | 127 +-------------- 3 files changed, 148 insertions(+), 150 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 15f16727e..177db29e4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -20,7 +20,7 @@ quantize_weight, ) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor +from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.transformers.finetune.data.data_helpers import ( create_single_batch_dataloader, @@ -35,7 +35,7 @@ __all__ = ["GPTQModifier"] -class GPTQModifier(Modifier): +class GPTQModifier(Modifier, LayerCompressorMixin): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -122,7 +122,6 @@ class GPTQModifier(Modifier): disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() - _layer_compressor: Optional[SequentialLayerCompressor] = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -135,13 +134,6 @@ def validate_sequential_update(cls, value: bool) -> bool: return True - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._layer_compressor = SequentialLayerCompressor( - self.quantize_module, self.true_sequential - ) - def on_initialize_structure(self, state: State, **kwargs): """ Check the model's quantization state matches that expected by this modifier, @@ -210,7 +202,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch - self._layer_compressor.register_hooks(state.model, self.sequential_targets) + self.register_hooks(state.model) # apply calibration and trigger hooks self.calibration_forward(state.model, state.data.calib) @@ -230,7 +222,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) - self._layer_compressor.remove_hooks() + self.remove_hooks() return True @@ -248,8 +240,11 @@ def calibration_forward( with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) - def quantize_module( - self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...] + def compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], ) -> float: """ Quantize a module's weight according to the GPTQ algorithm diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 19c9a34ce..d65e41d01 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,12 +1,22 @@ import contextlib +from abc import abstractmethod +from functools import partial +from typing import Any, Callable, ClassVar, Dict, List, Tuple import torch +from loguru import logger +from pydantic import BaseModel +from torch.utils.hooks import RemovableHandle -__all__ = ["HooksMixin"] +from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.metric_logging import CompressionLogger +from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params +__all__ = ["HooksMixin", "LayerCompressorMixin"] -class HooksMixin: - """" + +class HooksMixin(BaseModel): + """ " Class to manage the registration, disabling, and removal of hooks. Registering and removing hooks should be handled by modifier classes which inherit from this mixin, while disabling hooks should disable all hooks across modifiers. @@ -15,12 +25,14 @@ class HooksMixin: Modifiers must pass registered hooks handles to self.register_hook() and must remove hooks when finished using self.remove_hooks() """ - HOOKS_DISABLED: bool = False + + _HOOKS_DISABLED: ClassVar[bool] = False + _hooks: List[RemovableHandle] = [] @classmethod - def hook(cls, func): + def hook(cls, func: Callable[[Any], Any]): def wrapped(*args, **kwargs): - if cls.HOOKS_DISABLED: + if cls._HOOKS_DISABLED: return func(*args, **kwargs) @@ -35,15 +47,12 @@ def disable_hooks(cls): TODO: select which modifier hooks are disabled/ kept enabled """ try: - cls.HOOKS_DISABLED = True + cls._HOOKS_DISABLED = True yield finally: - cls.HOOKS_DISABLED = False - - def __init__(self): - self._hooks = [] + cls._HOOKS_DISABLED = False - def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + def register_hook(self, handle: RemovableHandle): """ Usage: self.register_hook(module.register_forward_hook(...)) @@ -57,3 +66,118 @@ def remove_hooks(self): """ for hook in self._hooks: hook.remove() + + +class LayerCompressorMixin(HooksMixin): + """ + Apply a given compression function to a model during the model's calibration + forward pass + + Lifecycle: + - QuantizationModifier.initialize(model) + - SequentialLayerCompressor(compress_fn) + - register_hooks(model) + - model.forward() + - compress_fn(name, target_module, args) + - remove_hooks() + + :ivar true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block), defaults to False + :ivar sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model + :ivar compresss_module: Function to be called on target modules + """ + + true_sequential: bool + sequential_targets: bool + # compress_module: Callable[[str, torch.nn.Module, Tuple], float] + + _layer_index = 0 + _num_layers = 0 + + @abstractmethod + def compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: + raise NotImplementedError() + + def register_hooks(self, model: torch.nn.Module): + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + sequential_targets = self.sequential_targets + if sequential_targets is None: + sequential_targets = get_no_split_params(model) + layers = get_layers(sequential_targets, model) + self._num_layers = len(layers) + + for name, module in model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_hook(post_hook)) + + if name in layers.keys(): + pre_hook = partial(self.layer_pre_forward, name) + post_hook = partial(self.layer_post_forward, name) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook( + module.register_forward_hook(post_hook, with_kwargs=True) + ) + + @HooksMixin.hook + def target_pre_forward( + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] + ): + if self.true_sequential: + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + @HooksMixin.hook + def target_post_forward( + self, + name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + _output: Tuple[Any, ...], + ): + if not self.true_sequential: + # compress after so output is from uncompressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + @HooksMixin.hook + def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): + logger.info( + f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" + ) + + @HooksMixin.hook + def layer_post_forward( + self, + _name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + output: Tuple[Any, ...], + ): + if not self.true_sequential: + # rerun with (now) compressed weights + with HooksMixin.disable_hooks(): + compressed_output = module(*args, **kwargs) + + error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) + logger.info(f"Mean output error from quantization: {error:.3f}") + + self._layer_index += 1 + return output diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index b168e2534..3f3aa3d02 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,6 +1,5 @@ import operator -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, Tuple import torch from compressed_tensors import get_execution_device @@ -8,135 +7,15 @@ from tqdm import tqdm from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper -from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.utils.fsdp.context import ( fix_fsdp_module_name, summon_full_params_context, ) -from llmcompressor.utils.helpers import getattr_chain -from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - get_prunable_layers, - set_layer, -) - -__all__ = ["SequentialLayerCompressor", "LayerCompressor"] - - -class SequentialLayerCompressor(HooksMixin): - """ - Apply a given compression function to a model during the model's calibration - forward pass - - Lifecycle: - - QuantizationModifier.initialize(model) - - SequentialLayerCompressor(compress_fn) - - register_hooks(model) - - model.forward() - - compress_fn(name, target_module, args) - - remove_hooks() - - :param compress_fn: Function to be called on target modules - :param true_sequential: Used to control the granularity of compression updates - through the forward pass. Set to True to use the weight-compressed outputs - of each module, set to False to use the weight-compressed outputs of each - layer (transformer block), defaults to False - """ - - def __init__( - self, - compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], float], - true_sequential: bool = False, - ): - HooksMixin.__init__(self) - self.compress_fn = compress_fn - self.true_sequential = true_sequential - - self._layer_index = 0 - self._num_layers = 0 - - def register_hooks( - self, - model: torch.nn.Module, - sequential_targets: Optional[Union[str, List[str]]] = None, - ): - # find layers (used for printing even if true_sequential=True) - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - if sequential_targets is None: - sequential_targets = get_no_split_params(model) - layers = get_layers(sequential_targets, model) - self._num_layers = len(layers) - - for name, module in model.named_modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - pre_hook = partial(self.target_pre_forward, name) - post_hook = partial(self.target_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) - self.register_hook(module.register_forward_hook(post_hook)) - - if name in layers.keys(): - pre_hook = partial(self.layer_pre_forward, name) - post_hook = partial(self.layer_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) - self.register_hook( - module.register_forward_hook(post_hook, with_kwargs=True) - ) - - @HooksMixin.hook - def target_pre_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] - ): - if self.true_sequential: - # compress first so output is from compressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_fn(name, module, args) - comp_logger.set_loss(loss) - - @HooksMixin.hook - def target_post_forward( - self, - name: str, - module: torch.nn.Module, - args: Tuple[Any, ...], - _output: Tuple[Any, ...], - ): - if not self.true_sequential: - # compress after so output is from uncompressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_fn(name, module, args) - comp_logger.set_loss(loss) - - @HooksMixin.hook - def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): - logger.info( - f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" - ) - - @HooksMixin.hook - def layer_post_forward( - self, - name: str, - module: torch.nn.Module, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - output: Tuple[Any, ...], - ): - if not self.true_sequential: - # rerun with (now) compressed weights - with HooksMixin.disable_hooks(): - compressed_output = module(*args, **kwargs) - - error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) - logger.info(f"Mean output error from quantization: {error:.3f}") +from llmcompressor.utils.pytorch.module import get_prunable_layers, set_layer - self._layer_index += 1 - return output +__all__ = ["LayerCompressor"] class LayerCompressor: From 7fbf8b193f193047a8d951fb4552a41799c327e1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 24 Oct 2024 03:56:21 +0000 Subject: [PATCH 39/59] docstrings --- src/llmcompressor/modifiers/utils/hooks.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index d65e41d01..f7242124c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -24,6 +24,11 @@ class HooksMixin(BaseModel): Modifiers which implement hooks should use the @HooksMixin.hook decorator Modifiers must pass registered hooks handles to self.register_hook() and must remove hooks when finished using self.remove_hooks() + + Lifecycle: + - Modifier.register_hooks(model) + - model.forward() + - Modifier.remove_hooks() """ _HOOKS_DISABLED: ClassVar[bool] = False @@ -75,11 +80,10 @@ class LayerCompressorMixin(HooksMixin): Lifecycle: - QuantizationModifier.initialize(model) - - SequentialLayerCompressor(compress_fn) - - register_hooks(model) + - Modifier.register_hooks(model) - model.forward() - compress_fn(name, target_module, args) - - remove_hooks() + - Modifier.remove_hooks() :ivar true_sequential: Used to control the granularity of compression updates through the forward pass. Set to True to use the weight-compressed outputs From 3d3af2ad0d7dd3b4d63176f9867aeb97f4a8cafd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 24 Oct 2024 04:49:55 +0000 Subject: [PATCH 40/59] add back hessian hook to support bs1 --- .../modifiers/quantization/gptq/base.py | 76 ++++++++++----- .../quantization/gptq/utils/gptq_quantize.py | 28 +++++- src/llmcompressor/utils/fsdp/helpers.py | 93 +++++++++++++++++++ 3 files changed, 170 insertions(+), 27 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 177db29e4..ed9e80269 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -17,6 +17,7 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( + add_batch, quantize_weight, ) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier @@ -25,6 +26,7 @@ from llmcompressor.transformers.finetune.data.data_helpers import ( create_single_batch_dataloader, ) +from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -200,15 +202,17 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # add hooks to targets and layers - # after lifecycle refactor, move this to pre_batch - self.register_hooks(state.model) + # trigger hessian hooks + self.register_hessians(state.model) + with calibration_forward_context(state.model): + run_calibration_forward(state.model, state.data.calib, mask_padding=True) + self.remove_hooks() - # apply calibration and trigger hooks - self.calibration_forward(state.model, state.data.calib) + self.register_hooks(state.model) + state.model(**state.model.dummy_inputs) + self.remove_hooks() # freeze quantization - # after lifecycle refactor, move this to post_batch state.model.apply(freeze_module_quantization) return True @@ -222,9 +226,31 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) - self.remove_hooks() - return True + + def hessian_hook(self, module, args): + # onload and offload + module.gptq_hessian = add_batch( + module.gptq_hessian.to(args[0].device), + module.gptq_hessian_samples, + module, + args[0] + ).to("cpu") + module.gptq_hessian_samples += 1 + + def register_hessians(self, model: torch.nn.Module): + for module in model.modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + num_columns = module.weight.shape[1] + + # hessian starts offloaded + module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") + module.gptq_hessian_samples = 0 + + self.register_hook(module.register_forward_pre_hook(self.hessian_hook)) + + + def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader @@ -261,24 +287,26 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - with align_module(module): - loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( - module.weight.data, - inp, - quant_args, - blocksize=self.block_size, - percdamp=self.dampening_frac, - module_class=type(module), - ) + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + module.gptq_hessian.data.to(module.weight.device), + quant_args, + blocksize=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + ) + + delattr(module, "gptq_hessian") + delattr(module, "gptq_hessian_samples") - # FUTURE: Implement learning rate modification to weight update + # FUTURE: Implement learning rate modification to weight update - if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", quantized_weight) - update_parameter_data(module, quantized_weight, "weight") - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") + if is_module_offloaded(module): + update_prefix_dict(self.layer, "weight", quantized_weight) + update_parameter_data(module, quantized_weight, "weight") + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") return loss diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 022252b0a..1ee435d4d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -17,6 +17,28 @@ GPTQ_PRECISION = torch.float32 +def add_batch(H: torch.Tensor, nsamples: int , module: torch.nn.Module, inp: torch.Tensor): + """ + Add a batch of layer input and output data to the Hessian calculation + """ + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(module, torch.nn.Linear) or isinstance( + module, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + H *= nsamples / (nsamples + tmp) + nsamples += tmp + inp = inp.to(dtype=H.dtype) + inp = math.sqrt(2 / nsamples) * inp + H += inp.matmul(inp.t()) + + return H + + def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs @@ -81,7 +103,7 @@ def compute_scale_zero_point( def quantize_weight( weight: torch.Tensor, - inp: torch.Tensor, + H: torch.Tensor, #inp: torch.Tensor, quant_args: QuantizationArgs, blocksize: int = 128, percdamp: float = 0.01, @@ -91,7 +113,7 @@ def quantize_weight( Quantize a module weight according to the GPTQ algorithm :param weight: weight being quantized - :param inp: module inputs used to calculate hessian + # :param inp: module inputs used to calculate hessian :param quant_args: quantization arguments used to find quantization parameters :param blocksize: chunk size of quantization updates :param percdamp: dampening factor on hessian diagonal @@ -104,7 +126,7 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() - H = compute_hessian(inp, module_class, device=weight.device) + #H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype if module_class == torch.nn.Conv2d: diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 8cc0f5405..80ef733f1 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,9 +1,12 @@ +import contextlib import operator from pathlib import Path from typing import Optional from loguru import logger +from llmcompressor.utils.helpers import getattr_chain + try: from torch.distributed.fsdp import ( FullStateDictConfig, @@ -179,3 +182,93 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: parent = operator.attrgetter(parent_name)(model) return parent + +def has_offloaded_params(module: torch.nn.Module) -> bool: + """ + Checks if a module has offloaded parameters by checking if the given module + has a AlignDevicesHook attached with offloading enabled + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. + """ + from accelerate.hooks import AlignDevicesHook + + return ( + hasattr(module, "_hf_hook") and + isinstance(module._hf_hook, AlignDevicesHook) and + module._hf_hook.offload + ) + +@contextlib.contextmanager +def align_module( + module: torch.nn.Module, + execution_device: Optional[torch.device] = None, + args = tuple(), kwargs = dict() +): + """ + Move a module's parameters to the execution device + :param module: module with parameters to align + :param execution_device: if provided, overrides module execution device + within the context + """ + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = original_device + + module._hf_hook.pre_forward(module, *args, **kwargs) + yield + module._hf_hook.post_forward(module, None) + + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = {} + for name, param in module.named_parameters(): + devices[name] = param.device + setattr(module, name, param.to(execution_device)) + + yield + + for name, param_device in module.named_parameters: + setattr(module, name, param.to(param_device)) + + else: + yield + + +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, + init_device: Optional[torch.device] = torch.device("cpu"), +): + """ + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param init_device: offload device for newly registered parameters + """ + param = getattr(module, name) + param.data = data + + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" + + offload_device = prefix_dict[key].device if key in prefix_dict else init_device + prefix_dict[key] = data.to(device=offload_device) + + +def register_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, + offload_device: Optional[torch.device] = torch.device("cpu"), +): + module.register_parameter(name, torch.nn.Parameter(data)) + update_offload_parameter(module, name, data, offload_device) \ No newline at end of file From b3021ab9e8d30aeb62786a11e464fcbb1ec5898f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 16:11:33 +0000 Subject: [PATCH 41/59] wip --- .../modifiers/quantization/gptq/base.py | 30 +------- src/llmcompressor/modifiers/utils/hooks.py | 69 +++++++++++++++++-- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index ed9e80269..d18e74249 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -203,13 +203,11 @@ def on_initialize(self, state: "State", **kwargs) -> bool: raise ValueError("To use the GPTQModifier, quantization must be enabled.") # trigger hessian hooks - self.register_hessians(state.model) + self.register_hooks(state.model) with calibration_forward_context(state.model): run_calibration_forward(state.model, state.data.calib, mask_padding=True) - self.remove_hooks() - self.register_hooks(state.model) - state.model(**state.model.dummy_inputs) + #state.model(**state.model.dummy_inputs) self.remove_hooks() # freeze quantization @@ -227,30 +225,6 @@ def on_finalize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.finalize(state, **kwargs) return True - - def hessian_hook(self, module, args): - # onload and offload - module.gptq_hessian = add_batch( - module.gptq_hessian.to(args[0].device), - module.gptq_hessian_samples, - module, - args[0] - ).to("cpu") - module.gptq_hessian_samples += 1 - - def register_hessians(self, model: torch.nn.Module): - for module in model.modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - num_columns = module.weight.shape[1] - - # hessian starts offloaded - module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") - module.gptq_hessian_samples = 0 - - self.register_hook(module.register_forward_pre_hook(self.hessian_hook)) - - - def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index f7242124c..da1094872 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,13 +1,15 @@ import contextlib from abc import abstractmethod from functools import partial -from typing import Any, Callable, ClassVar, Dict, List, Tuple +from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple import torch from loguru import logger from pydantic import BaseModel from torch.utils.hooks import RemovableHandle +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params @@ -100,6 +102,9 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 + _pre_active: Set[torch.nn.Module] = set() + _module_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] + _module_outputs: List[Tuple[Any, ...]] = [] @abstractmethod def compress_module( @@ -125,7 +130,7 @@ def register_hooks(self, model: torch.nn.Module): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) if name in layers.keys(): @@ -138,22 +143,74 @@ def register_hooks(self, model: torch.nn.Module): @HooksMixin.hook def target_pre_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): - if self.true_sequential: - # compress first so output is from compressed weights + if module in self._pre_active: + return + + if not hasattr(module, "gptq_hessian"): + print("init hessian") + num_columns = module.weight.shape[1] + + # hessian starts offloaded + module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") + module.gptq_hessian_samples = 0 + + print("add to hessian") + # onload and offload + module.gptq_hessian = add_batch( + module.gptq_hessian.to(args[0].device), + module.gptq_hessian_samples, + module, + args[0] + ).to("cpu") + module.gptq_hessian_samples += 1 + self._module_inputs.append((args, kwargs)) + + if module.gptq_hessian_samples >= 2: + print("compress") with CompressionLogger(module) as comp_logger: loss = self.compress_module(name, module, args) comp_logger.set_loss(loss) + self._pre_active.add(module) + for args, kwargs in self._module_inputs: + try: + module(*args, **kwargs) + except EarlyStopException: + pass + + raise EarlyStopException(torch.Tensor([]), None) + @HooksMixin.hook def target_post_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], - _output: Tuple[Any, ...], + output: Tuple[Any, ...], ): + print("target_post_forward") + return + # accumulate + self._module_outputs.append(output) + + if len(self._module_outputs) == 2: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + ret = self._module_outputs + self._module_outputs = [] + + return ret + + if self.true_sequential: + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + if not self.true_sequential: # compress after so output is from uncompressed weights with CompressionLogger(module) as comp_logger: From 8508b633f14d03e92f762a1fc98818b09ffefd98 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 18:37:42 +0000 Subject: [PATCH 42/59] accumulate --- .../modifiers/quantization/gptq/base.py | 4 +- .../quantization/gptq/utils/gptq_quantize.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 72 +++++++++++-------- 3 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index d18e74249..68d603a4b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -202,10 +202,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # trigger hessian hooks self.register_hooks(state.model) - with calibration_forward_context(state.model): - run_calibration_forward(state.model, state.data.calib, mask_padding=True) + self.calibration_forward(state.model, state.data.calib) #state.model(**state.model.dummy_inputs) self.remove_hooks() diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 1ee435d4d..4365ce3d3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -36,7 +36,7 @@ def add_batch(H: torch.Tensor, nsamples: int , module: torch.nn.Module, inp: tor inp = math.sqrt(2 / nsamples) * inp H += inp.matmul(inp.t()) - return H + return H, nsamples def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index da1094872..83f50d6b1 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,7 +1,7 @@ import contextlib from abc import abstractmethod from functools import partial -from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple +from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple, Union import torch from loguru import logger @@ -104,7 +104,10 @@ class LayerCompressorMixin(HooksMixin): _num_layers = 0 _pre_active: Set[torch.nn.Module] = set() _module_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] - _module_outputs: List[Tuple[Any, ...]] = [] + _module_outputs: Union[List[Tuple[Any, ...]], torch.Tensor] = [] + + _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] + _layer_outputs: List[Tuple[Any, ...]] = [] @abstractmethod def compress_module( @@ -136,7 +139,7 @@ def register_hooks(self, model: torch.nn.Module): if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook( module.register_forward_hook(post_hook, with_kwargs=True) ) @@ -145,42 +148,42 @@ def register_hooks(self, model: torch.nn.Module): def target_pre_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): - if module in self._pre_active: - return + input = args[0] + # compute hessian if not hasattr(module, "gptq_hessian"): - print("init hessian") - num_columns = module.weight.shape[1] - # hessian starts offloaded - module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") + num_columns = module.weight.shape[1] + module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device=input.device) module.gptq_hessian_samples = 0 - print("add to hessian") - # onload and offload - module.gptq_hessian = add_batch( - module.gptq_hessian.to(args[0].device), + module.gptq_hessian, module.gptq_hessian_samples = add_batch( + module.gptq_hessian, module.gptq_hessian_samples, module, - args[0] - ).to("cpu") - module.gptq_hessian_samples += 1 - self._module_inputs.append((args, kwargs)) - + input + ) + if module.gptq_hessian_samples >= 2: - print("compress") - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + # if true, compress + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) - self._pre_active.add(module) - for args, kwargs in self._module_inputs: - try: - module(*args, **kwargs) - except EarlyStopException: - pass + else: + raise EarlyStopException(torch.Tensor([]), None) - raise EarlyStopException(torch.Tensor([]), None) + # forward with individuals + forward_call = (module._slow_forward if torch._C._get_tracing_state() else module.forward) + self._module_outputs = [ + forward_call(input[batch_index: batch_index + 1]) + for batch_index in range(input.shape[0]) + ] + + self._module_outputs = torch.concat(self._module_outputs) + + return (input[0:1], *args[1:]), kwargs @HooksMixin.hook def target_post_forward( @@ -191,7 +194,11 @@ def target_post_forward( output: Tuple[Any, ...], ): print("target_post_forward") - return + + ret = self._module_outputs + self._module_outputs = [] + return ret + # accumulate self._module_outputs.append(output) @@ -218,11 +225,14 @@ def target_post_forward( comp_logger.set_loss(loss) @HooksMixin.hook - def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): + def layer_pre_forward(self, _name: str, layer: torch.nn.Module, _args: Any, kwargs): logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) + + + @HooksMixin.hook def layer_post_forward( self, From 3ff271d87fec32b995c5d76d409abefd3712c388 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 21:16:21 +0000 Subject: [PATCH 43/59] virtualize batches for layers --- src/llmcompressor/modifiers/utils/hooks.py | 98 +++++++++++++++------- 1 file changed, 68 insertions(+), 30 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 83f50d6b1..d5331f640 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -7,6 +7,7 @@ from loguru import logger from pydantic import BaseModel from torch.utils.hooks import RemovableHandle +from collections import defaultdict from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException @@ -42,7 +43,7 @@ def wrapped(*args, **kwargs): if cls._HOOKS_DISABLED: return - func(*args, **kwargs) + return func(*args, **kwargs) return wrapped @@ -103,8 +104,8 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 _pre_active: Set[torch.nn.Module] = set() - _module_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] - _module_outputs: Union[List[Tuple[Any, ...]], torch.Tensor] = [] + _module_inputs: Dict[torch.nn.Module, List[Tuple[Tuple[Any, ...], Dict[str, Any]]]] = defaultdict(lambda: []) + _module_outputs: Dict[torch.nn.Module, Union[List[Tuple[Any, ...]], torch.Tensor]] = defaultdict(lambda: []) _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] _layer_outputs: List[Tuple[Any, ...]] = [] @@ -143,6 +144,7 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook( module.register_forward_hook(post_hook, with_kwargs=True) ) + @HooksMixin.hook def target_pre_forward( @@ -152,11 +154,11 @@ def target_pre_forward( # compute hessian if not hasattr(module, "gptq_hessian"): - num_columns = module.weight.shape[1] module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device=input.device) module.gptq_hessian_samples = 0 + print(f"{name} adding {input.size(0)} samples") module.gptq_hessian, module.gptq_hessian_samples = add_batch( module.gptq_hessian, module.gptq_hessian_samples, @@ -164,26 +166,6 @@ def target_pre_forward( input ) - if module.gptq_hessian_samples >= 2: - # if true, compress - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - else: - raise EarlyStopException(torch.Tensor([]), None) - - # forward with individuals - forward_call = (module._slow_forward if torch._C._get_tracing_state() else module.forward) - self._module_outputs = [ - forward_call(input[batch_index: batch_index + 1]) - for batch_index in range(input.shape[0]) - ] - - self._module_outputs = torch.concat(self._module_outputs) - - return (input[0:1], *args[1:]), kwargs @HooksMixin.hook def target_post_forward( @@ -193,10 +175,21 @@ def target_post_forward( args: Tuple[Any, ...], output: Tuple[Any, ...], ): - print("target_post_forward") + print(f"post {name}") - ret = self._module_outputs - self._module_outputs = [] + if module.gptq_hessian_samples >= 512: + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + """ + breakpoint() + ret = torch.concat(self._module_outputs) + del self._module_inputs[module] + del self._module_outputs[module] return ret # accumulate @@ -223,25 +216,70 @@ def target_post_forward( with CompressionLogger(module) as comp_logger: loss = self.compress_module(name, module, args) comp_logger.set_loss(loss) + """ @HooksMixin.hook - def layer_pre_forward(self, _name: str, layer: torch.nn.Module, _args: Any, kwargs): + def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) - + input = args[0] + + if not self.true_sequential: + self._module_inputs[layer] += [ + input[batch_index: batch_index + 1] + for batch_index in range(input.shape[0]) + ] + + # forward with individuals (might not be necessary) + forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) + self._module_outputs[layer] = [] + for batch_index in range(input.size(0) - 1): + print("layer forward") + output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) + self._module_outputs[layer].append(output) + pass + + # last sample can be passed normally + print("last layer forward") + + return (input[-1:], *args[1:]), kwargs @HooksMixin.hook def layer_post_forward( self, - _name: str, + name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any], output: Tuple[Any, ...], ): + print(f"post {name}") + breakpoint() + + # capture last sample + self._module_outputs[module].append(output) + + # batch outputs + outputs = self._module_outputs[module] + batched_outputs = tuple( + torch.concat(tuple( + outputs[sample_index][output_index] + for sample_index in range(len(outputs)) + )) + for output_index in range(len(outputs[0])) + ) + del self._module_outputs[module] + + if not self.true_sequential: + pass # run again + + del self._module_inputs[module] + + return batched_outputs + if not self.true_sequential: # rerun with (now) compressed weights with HooksMixin.disable_hooks(): From d6c6dc339381cf5eb893e6134604b12a25fc6127 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 22:02:10 +0000 Subject: [PATCH 44/59] maybe works, but padding is wrong --- .../modifiers/quantization/gptq/base.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 37 ++++++++++--------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 68d603a4b..767664640 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -234,7 +234,7 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - dataloader = create_single_batch_dataloader(dataloader.dataset) + #dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index d5331f640..322f7b787 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -177,7 +177,7 @@ def target_post_forward( ): print(f"post {name}") - if module.gptq_hessian_samples >= 512: + if module.gptq_hessian_samples >= 20: # compress print(f"compressing {name}") if True: #self.true_sequential: @@ -232,26 +232,27 @@ def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs for batch_index in range(input.shape[0]) ] - # forward with individuals (might not be necessary) - forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) - self._module_outputs[layer] = [] - for batch_index in range(input.size(0) - 1): - print("layer forward") - output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) - self._module_outputs[layer].append(output) - pass - # last sample can be passed normally - print("last layer forward") + if len(self._module_outputs[layer]) >= 20 - 1: + # last sample can be passed normally + print("last layer forward") + return (input[-1:], *args[1:]), kwargs + + else: + forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) + for batch_index in range(input.size(0)): + print("layer forward") + output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) + self._module_outputs[layer].append(output) - return (input[-1:], *args[1:]), kwargs + raise EarlyStopException(torch.tensor([]), None) @HooksMixin.hook def layer_post_forward( self, name: str, - module: torch.nn.Module, + layer: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any], output: Tuple[Any, ...], @@ -260,10 +261,10 @@ def layer_post_forward( breakpoint() # capture last sample - self._module_outputs[module].append(output) + self._module_outputs[layer].append(output) # batch outputs - outputs = self._module_outputs[module] + outputs = self._module_outputs[layer] batched_outputs = tuple( torch.concat(tuple( outputs[sample_index][output_index] @@ -271,19 +272,19 @@ def layer_post_forward( )) for output_index in range(len(outputs[0])) ) - del self._module_outputs[module] + del self._module_outputs[layer] if not self.true_sequential: pass # run again - del self._module_inputs[module] + del self._module_inputs[layer] return batched_outputs if not self.true_sequential: # rerun with (now) compressed weights with HooksMixin.disable_hooks(): - compressed_output = module(*args, **kwargs) + compressed_output = layer(*args, **kwargs) error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) logger.info(f"Mean output error from quantization: {error:.3f}") From 400fa0864875b15bbf0ff15d1a85db40b3e9c656 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 14:54:42 +0000 Subject: [PATCH 45/59] WIP --- .../modifiers/quantization/gptq/base.py | 42 ++-- .../quantization/gptq/utils/gptq_quantize.py | 19 +- src/llmcompressor/modifiers/utils/hooks.py | 81 +++----- src/llmcompressor/utils/fsdp/helpers.py | 188 +++++++++++++++--- 4 files changed, 218 insertions(+), 112 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 767664640..7e6e5556d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -259,26 +259,34 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( - module.weight.data, - module.gptq_hessian.data.to(module.weight.device), - quant_args, - blocksize=self.block_size, - percdamp=self.dampening_frac, - module_class=type(module), - ) + offloaded = is_module_offloaded(module) + if offloaded: + module._hf_hook.pre_forward(module) + + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + inp, + quant_args, + blocksize=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + original_weight=module.original_weight.data, + ) + + delattr(module, "gptq_hessian") + delattr(module, "gptq_hessian_samples") - delattr(module, "gptq_hessian") - delattr(module, "gptq_hessian_samples") + # FUTURE: Implement learning rate modification to weight update - # FUTURE: Implement learning rate modification to weight update + if is_module_offloaded(module): + update_prefix_dict(self.layer, "weight", quantized_weight) + update_parameter_data(module, quantized_weight, "weight") + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") - if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", quantized_weight) - update_parameter_data(module, quantized_weight, "weight") - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") + if offloaded: + module._hf_hook.post_forward(module, None) return loss diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 4365ce3d3..5f4f0cd22 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,6 +1,6 @@ import math from copy import copy -from typing import Tuple, Union +from typing import Tuple, Union, Optional, Type import torch import transformers @@ -82,7 +82,8 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: def compute_scale_zero_point( - W: torch.Tensor, quant_args: QuantizationArgs + W: torch.Tensor, + quant_args: QuantizationArgs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute the scale and zero point of a module weight @@ -103,17 +104,19 @@ def compute_scale_zero_point( def quantize_weight( weight: torch.Tensor, - H: torch.Tensor, #inp: torch.Tensor, + inp: torch.Tensor, quant_args: QuantizationArgs, blocksize: int = 128, percdamp: float = 0.01, - module_class=torch.nn.Linear, + module_class: Type[torch.nn.Module] = torch.nn.Linear, + original_weight: Optional[torch.Tensor] = None, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: """ Quantize a module weight according to the GPTQ algorithm + TODO :param weight: weight being quantized - # :param inp: module inputs used to calculate hessian + :param inp: module inputs used to calculate hessian :param quant_args: quantization arguments used to find quantization parameters :param blocksize: chunk size of quantization updates :param percdamp: dampening factor on hessian diagonal @@ -126,7 +129,7 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() - #H = compute_hessian(inp, module_class, device=weight.device) + H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -199,9 +202,9 @@ def quantize_weight( W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): - w = W1[:, i] + w = original_weight[:, i] d = Hinv1[i, i] - q = w.clone() + q = W1[:, i].clone() # quantize column if strategy == QuantizationStrategy.TENSOR: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 322f7b787..414c75bc1 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -11,6 +11,7 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException +from llmcompressor.utils.fsdp.helpers import register_offload_parameter from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params @@ -99,11 +100,9 @@ class LayerCompressorMixin(HooksMixin): true_sequential: bool sequential_targets: bool - # compress_module: Callable[[str, torch.nn.Module, Tuple], float] _layer_index = 0 _num_layers = 0 - _pre_active: Set[torch.nn.Module] = set() _module_inputs: Dict[torch.nn.Module, List[Tuple[Tuple[Any, ...], Dict[str, Any]]]] = defaultdict(lambda: []) _module_outputs: Dict[torch.nn.Module, Union[List[Tuple[Any, ...]], torch.Tensor]] = defaultdict(lambda: []) @@ -137,6 +136,10 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) + #register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? + register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? + register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? + if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) @@ -166,6 +169,14 @@ def target_pre_forward( input ) + if self.true_sequential: + if module.gptq_hessian_samples >= 20: + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def target_post_forward( @@ -176,14 +187,14 @@ def target_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + if not self.true_sequential: + if module.gptq_hessian_samples >= 20: + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) """ breakpoint() @@ -223,30 +234,7 @@ def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) - - input = args[0] - - if not self.true_sequential: - self._module_inputs[layer] += [ - input[batch_index: batch_index + 1] - for batch_index in range(input.shape[0]) - ] - - - if len(self._module_outputs[layer]) >= 20 - 1: - # last sample can be passed normally - print("last layer forward") - return (input[-1:], *args[1:]), kwargs - - else: - forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) - for batch_index in range(input.size(0)): - print("layer forward") - output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) - self._module_outputs[layer].append(output) - - raise EarlyStopException(torch.tensor([]), None) - + @HooksMixin.hook def layer_post_forward( @@ -258,30 +246,9 @@ def layer_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - breakpoint() - - # capture last sample - self._module_outputs[layer].append(output) - - # batch outputs - outputs = self._module_outputs[layer] - batched_outputs = tuple( - torch.concat(tuple( - outputs[sample_index][output_index] - for sample_index in range(len(outputs)) - )) - for output_index in range(len(outputs[0])) - ) - del self._module_outputs[layer] - - if not self.true_sequential: - pass # run again - - del self._module_inputs[layer] - - return batched_outputs - if not self.true_sequential: + + if False and not self.true_sequential: # only print # rerun with (now) compressed weights with HooksMixin.disable_hooks(): compressed_output = layer(*args, **kwargs) diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 80ef733f1..e5ecda8c7 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,7 +1,9 @@ import contextlib +from functools import wraps import operator from pathlib import Path from typing import Optional +import warnings from loguru import logger @@ -23,6 +25,16 @@ from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe from llmcompressor.utils.pytorch import set_layer +try: + from accelerate.hooks import AlignDevicesHook + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + _has_accelerate = True +except ImportError: + _has_accelerate = False + AlignDevicesHook = None + OffloadedWeightsLoader = None + PrefixedDataset = None + __all__ = [ "is_fsdp_model", "maybe_get_wrapped", @@ -183,32 +195,150 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: return parent +# upstream candidate def has_offloaded_params(module: torch.nn.Module) -> bool: """ Checks if a module has offloaded parameters by checking if the given module has a AlignDevicesHook attached with offloading enabled + Args: module (`torch.nn.Module`): The module to check for an offload hook. + Returns: bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise. """ - from accelerate.hooks import AlignDevicesHook - return ( hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload ) -@contextlib.contextmanager -def align_module( + +# depreciation candidate +@wraps(has_offloaded_params) +def is_module_offloaded(module: torch.nn.Module) -> bool: + if not _has_accelerate: + return False + + return has_offloaded_params(module) + + +# depreciation candidate +def get_execution_device(module: torch.nn.Module) -> torch.device: + """ + :param module: module to check + :return: device module is loaded onto during forward pass + """ + if is_module_offloaded(module): + return module._hf_hook.execution_device + device = next(module.parameters()).device + + # offload only gets set for leaf modules, fallback to checking for device type + if device.type == "meta": + return module._hf_hook.execution_device + + return device + + +# upstream candidate +def _infer_offload_device(module: torch.nn.Module) -> torch.device: + if not has_offloaded_params(module): + raise ValueError("Cannot infer offload device from non-offloaded module") + + first_key = next(module._hf_hook.weights_map.keys(), None) + if first_key is None: + raise ValueError("Cannot infer offload device from empty weights map") + + prefix_dataset = module._hf_hook.weights_map.dataset + return prefix_dataset[first_key].device + +# depreciation candidate +def get_offloaded_device(module: torch.nn.Module) -> torch.device: + """ + :param module: module to check + :return: device module is offloaded to onto after forward pass + """ + return _infer_offload_device(module) + + +# depreciation candidate +def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): + """ + Updates the offloaded state dict for a given module. Parameter named key is replaced + by data. This is neccesary because parameter updates for offloaded modules do not + persist automatically between loads. This function only affects the offloaded + state dict and not the current state of the loaded module. + + :param module: module containing the parameter to update + :param key: name of parameter to update + :param data: tensor to update parameter with in the offloaded state dict + """ + if not is_module_offloaded(module): + raise ValueError("Prefix dict is only applicable to offloaded modules") + prefix_dict = module._hf_hook.weights_map + prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data + + +# upstream candidate? +def update_offload_parameter( module: torch.nn.Module, - execution_device: Optional[torch.device] = None, - args = tuple(), kwargs = dict() + name: str, + data: torch.Tensor, + offload_device: Optional[torch.device] = None, ): + """ + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param offload_device: offload device for newly registered parameters + """ + if data.device == "meta": + raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") + + param = getattr(module, name) + if param.data.dtype != data.dtype: + warnings.warn("TODO") + param.data.copy_(data) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" + + offload_device = ( + prefix_dict[key].device if key in prefix_dict + else offload_device if offload_device is not None + else _infer_offload_device(module) + ) + prefix_dict[key] = data.to(device=offload_device) + + if isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + else: + raise NotImplementedError() + +# depreciation candidate +def update_parameter_data( + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str +): + param = getattr(module, param_name) + new_param_data = new_param_data.to(device=param.device, dtype=param.dtype) + update_offload_parameter(module, param_name, new_param_data) + + +# upstream candidate +@contextlib.contextmanager +def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ Move a module's parameters to the execution device + :param module: module with parameters to align :param execution_device: if provided, overrides module execution device within the context @@ -218,7 +348,7 @@ def align_module( original_device = module._hf_hook.execution_device module._hf_hook.execution_device = original_device - module._hf_hook.pre_forward(module, *args, **kwargs) + module._hf_hook.pre_forward(module) yield module._hf_hook.post_forward(module, None) @@ -240,35 +370,33 @@ def align_module( yield -def update_offload_parameter( +@contextlib.contextmanager +def modify_offload_module( module: torch.nn.Module, - name: str, - data: torch.Tensor, - init_device: Optional[torch.device] = torch.device("cpu"), + execution_device: Optional[torch.device] = None, + offload_device: Optional[torch.device] = None, ): - """ - :param module: module containing the parameter to update - :param name: name of module parameter to update - :param data: tensor to update parameter with - :param init_device: offload device for newly registered parameters - """ - param = getattr(module, name) - param.data = data - - prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) - if prefix_dict is not None: - prefix = module._hf_hook.weights_map.prefix - key = f"{prefix}{name}" + with align_module(module, execution_device): + yield - offload_device = prefix_dict[key].device if key in prefix_dict else init_device - prefix_dict[key] = data.to(device=offload_device) + # there is little performance gain from checking if a parameter's data + # has been modified before copying since the new data must be copied + # to the offload device anyways; just update all module parameters + for name, param in module.named_parameters(): + update_offload_parameter(module, name, param.data, offload_device) +# upstream candidate? def register_offload_parameter( module: torch.nn.Module, name: str, - data: torch.Tensor, - offload_device: Optional[torch.device] = torch.device("cpu"), + parameter: torch.nn.Parameter, + offload_device: Optional[torch.device] = None, ): - module.register_parameter(name, torch.nn.Parameter(data)) - update_offload_parameter(module, name, data, offload_device) \ No newline at end of file + module.register_parameter(name, parameter) + update_offload_parameter(module, name, parameter.data, offload_device) + + +# upstream candidate? +def deregister_offload_parameter(): + raise NotImplementedError() \ No newline at end of file From 03515f0ecfab91061e2b60fb338ed1b1a898533f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:52:12 +0000 Subject: [PATCH 46/59] remove hessian --- .../modifiers/quantization/gptq/base.py | 18 +--- src/llmcompressor/modifiers/utils/hooks.py | 85 ++----------------- 2 files changed, 12 insertions(+), 91 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7e6e5556d..119179666 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -26,7 +26,7 @@ from llmcompressor.transformers.finetune.data.data_helpers import ( create_single_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter +from llmcompressor.utils.fsdp.helpers import has_offloaded_params from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -259,10 +259,7 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - + with align_module(module): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, @@ -273,21 +270,14 @@ def compress_module( original_weight=module.original_weight.data, ) - delattr(module, "gptq_hessian") - delattr(module, "gptq_hessian_samples") - - # FUTURE: Implement learning rate modification to weight update + #weight_update_acc = module.weight_update_acc.data + quantized_weight + #update_parameter_data(module, quantized_weight, "weight") - if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", quantized_weight) update_parameter_data(module, quantized_weight, "weight") update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") - if offloaded: - module._hf_hook.post_forward(module, None) - return loss def _build_quant_modifier(self): diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 414c75bc1..3bda3da2c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -137,8 +137,8 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_hook(post_hook)) #register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? - register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? - register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? + #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? + #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) @@ -153,30 +153,12 @@ def register_hooks(self, model: torch.nn.Module): def target_pre_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): - input = args[0] - - # compute hessian - if not hasattr(module, "gptq_hessian"): - num_columns = module.weight.shape[1] - module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device=input.device) - module.gptq_hessian_samples = 0 - - print(f"{name} adding {input.size(0)} samples") - module.gptq_hessian, module.gptq_hessian_samples = add_batch( - module.gptq_hessian, - module.gptq_hessian_samples, - module, - input - ) - - if self.true_sequential: - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def target_post_forward( @@ -187,47 +169,6 @@ def target_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - if not self.true_sequential: - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - """ - breakpoint() - ret = torch.concat(self._module_outputs) - del self._module_inputs[module] - del self._module_outputs[module] - return ret - - # accumulate - self._module_outputs.append(output) - - if len(self._module_outputs) == 2: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - ret = self._module_outputs - self._module_outputs = [] - - return ret - - if self.true_sequential: - # compress first so output is from compressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - if not self.true_sequential: - # compress after so output is from uncompressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - """ @HooksMixin.hook def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): @@ -246,15 +187,5 @@ def layer_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - - - if False and not self.true_sequential: # only print - # rerun with (now) compressed weights - with HooksMixin.disable_hooks(): - compressed_output = layer(*args, **kwargs) - - error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) - logger.info(f"Mean output error from quantization: {error:.3f}") - self._layer_index += 1 return output From 6e37f649e63c81e402389c5a009b74bf70be6eb9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:54:00 +0000 Subject: [PATCH 47/59] allocated original weight --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 4 ++-- src/llmcompressor/modifiers/utils/hooks.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 119179666..6026fe530 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -267,7 +267,7 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - original_weight=module.original_weight.data, + weight_original=module.weight_original.data, ) #weight_update_acc = module.weight_update_acc.data + quantized_weight diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 5f4f0cd22..617572391 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -109,7 +109,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class: Type[torch.nn.Module] = torch.nn.Linear, - original_weight: Optional[torch.Tensor] = None, + weight_original: Optional[torch.Tensor] = None, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: """ Quantize a module weight according to the GPTQ algorithm @@ -202,7 +202,7 @@ def quantize_weight( W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): - w = original_weight[:, i] + w = weight_original[:, i] d = Hinv1[i, i] q = W1[:, i].clone() diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 3bda3da2c..f3d674538 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -136,7 +136,7 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - #register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? + register_offload_parameter(module, "weight_original", module.weight.clone()) # TODO: better name? #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? From 09dae14c7e661fecdef2f8b90cd01e1266733c03 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:56:18 +0000 Subject: [PATCH 48/59] proper clone --- src/llmcompressor/modifiers/utils/hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index f3d674538..7c8858cbc 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -136,7 +136,8 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - register_offload_parameter(module, "weight_original", module.weight.clone()) # TODO: better name? + breakpoint() + register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? From 944601e06a2c8279e1bc4e80f4157c4a1405d225 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:56:43 +0000 Subject: [PATCH 49/59] remove breakpoint --- src/llmcompressor/modifiers/utils/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 7c8858cbc..05f7dfa3c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -136,7 +136,6 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - breakpoint() register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? From adbcee8ccc407d4d03c20a97b5fea467831611a1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:43:49 +0000 Subject: [PATCH 50/59] naive_update option --- .../modifiers/quantization/gptq/base.py | 72 +++++++++++++++---- src/llmcompressor/modifiers/utils/hooks.py | 13 +++- .../finetune/data/data_helpers.py | 8 ++- src/llmcompressor/utils/fsdp/helpers.py | 11 +-- 4 files changed, 79 insertions(+), 25 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 6026fe530..c26672422 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import math from compressed_tensors.quantization import ( QuantizationScheme, freeze_module_quantization, @@ -24,14 +25,18 @@ from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.transformers.finetune.data.data_helpers import ( - create_single_batch_dataloader, + create_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import has_offloaded_params +from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter, update_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, getattr_chain, ) +from compressed_tensors.quantization import ( + fake_quantize, +) + from llmcompressor.utils.pytorch.module import qat_active __all__ = ["GPTQModifier"] @@ -75,10 +80,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): :param sequential_update: Whether or not to update weights sequentially by layer. This option is depreciated and setting to False is no longer supported - :param true_sequential: Used to control the granularity of compression updates - through the forward pass. Set to True to use the weight-compressed outputs - of each module, set to False to use the weight-compressed outputs of each - layer (transformer block), defaults to False + :param naive_update: TODO :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass @@ -109,7 +111,8 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - true_sequential: bool = False + naive_update: bool = False + batch_size: int = 1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -124,6 +127,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() + _num_batches: int = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -201,6 +205,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + + self._num_batches = math.ceil(len(state.data.calib.dataset) / self.batch_size) self.register_hooks(state.model) self.calibration_forward(state.model, state.data.calib) @@ -222,6 +228,31 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) + for module in state.model.modules(): + with align_module(module): + quant_args = getattr_chain(module, "quantization_scheme.weights", None) + if quant_args is None: + continue + + if self.naive_update: + weight = module.weight_acc / self._num_batches + delattr(module, "weight_acc") + + if self.naive_update: + weight = module.weight + delattr(module, "weight_original") + + scale, zero_point = quant_args.get_observer()(weight) + weight = fake_quantize( + weight, + scale, + zero_point, + quant_args, + ) + update_offload_parameter(module, "weight", weight) + update_offload_parameter(module, "scale", scale) + update_offload_parameter(module, "zero_point", zero_point) + return True def calibration_forward( @@ -234,10 +265,18 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - #dataloader = create_single_batch_dataloader(dataloader.dataset) + dataloader = create_batch_dataloader(dataloader.dataset, batch_size=self.batch_size) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) + def pre_compress_module(self, module: torch.nn.Module): + # TODO: better names? + if self.naive_update: + register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) + + else: + register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) + def compress_module( self, name: str, @@ -267,16 +306,19 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=module.weight_original.data, + weight_original=module.weight_original.data if self.naive_update else module.weight.data ) - #weight_update_acc = module.weight_update_acc.data + quantized_weight - #update_parameter_data(module, quantized_weight, "weight") + if self.naive_update: + module.weight_acc += quantized_weight + update_offload_parameter(module, "weight_acc") + else: + module.weight += (quantized_weight - module.weight) * self._num_batches + update_offload_parameter(module, "weight") - update_parameter_data(module, quantized_weight, "weight") - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") + scale, zero_point = quant_args.get_observer()(module.weight) + update_offload_parameter(module, "scale", scale) + update_offload_parameter(module, "zero_point", zero_point) return loss diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 05f7dfa3c..bf03f6e86 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -109,6 +109,15 @@ class LayerCompressorMixin(HooksMixin): _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] _layer_outputs: List[Tuple[Any, ...]] = [] + @abstractmethod + def pre_compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: + raise NotImplementedError() + @abstractmethod def compress_module( self, @@ -136,9 +145,7 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? - #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? - #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? + self.pre_compress_module(module) if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index cc1c946ac..6f336aa8b 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -13,7 +13,7 @@ LABELS_MASK_VALUE = -100 __all__ = [ - "create_single_batch_dataloader", + "create_batch_dataloader", "format_calibration_data", "get_raw_dataset", "make_dataset_splits", @@ -22,13 +22,15 @@ ] -def create_single_batch_dataloader( +def create_batch_dataloader( dataset: datasets.Dataset, + batch_size: int, ) -> torch.utils.data.DataLoader: """ Create a dataloader whose batch size is equal to the size of the dataset :param dataset: dataset used to generate dataloader + :param batch_size: batch size of new dataloader :return: dataloader """ @@ -49,7 +51,7 @@ def pad_sequences(batch): return torch.utils.data.DataLoader( dataset, - batch_size=len(dataset), + batch_size=batch_size, shuffle=True, collate_fn=pad_sequences, pin_memory=True, diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index e5ecda8c7..2707cbae8 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -284,7 +284,7 @@ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): def update_offload_parameter( module: torch.nn.Module, name: str, - data: torch.Tensor, + data: Optional[torch.Tensor] = None, offload_device: Optional[torch.device] = None, ): """ @@ -297,9 +297,12 @@ def update_offload_parameter( raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") param = getattr(module, name) - if param.data.dtype != data.dtype: - warnings.warn("TODO") - param.data.copy_(data) + if data is None: + data = param.data + else: + if param.data.dtype != data.dtype: + warnings.warn("TODO") + param.data.copy_(data) if has_offloaded_params(module): weights_map = module._hf_hook.weights_map From f4acab20dcb032662f41ec1acbdce4632e94cc99 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:45:29 +0000 Subject: [PATCH 51/59] remove true sequential --- src/llmcompressor/modifiers/utils/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index bf03f6e86..dce60a6a5 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -98,7 +98,6 @@ class LayerCompressorMixin(HooksMixin): :ivar compresss_module: Function to be called on target modules """ - true_sequential: bool sequential_targets: bool _layer_index = 0 From 151f566730645f820b9fcd9dac6c006ef4ed7595 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:48:09 +0000 Subject: [PATCH 52/59] allow update_offload_parameter to not require data --- src/llmcompressor/utils/fsdp/helpers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 2707cbae8..5b60b68e2 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -293,15 +293,14 @@ def update_offload_parameter( :param data: tensor to update parameter with :param offload_device: offload device for newly registered parameters """ - if data.device == "meta": - raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") - param = getattr(module, name) - if data is None: - data = param.data - else: + if data is not None: + if data.device == "meta": + raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") + if param.data.dtype != data.dtype: warnings.warn("TODO") + param.data.copy_(data) if has_offloaded_params(module): @@ -319,7 +318,7 @@ def update_offload_parameter( else offload_device if offload_device is not None else _infer_offload_device(module) ) - prefix_dict[key] = data.to(device=offload_device) + prefix_dict[key] = param.data.to(device=offload_device) if isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() From 76ebc8609c05ced749784f3b12e6df2cbc05eb45 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:49:25 +0000 Subject: [PATCH 53/59] bugfix --- src/llmcompressor/modifiers/quantization/gptq/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c26672422..4a22c6f9b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -111,7 +111,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - naive_update: bool = False + naive_update: bool = True batch_size: int = 1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -250,8 +250,8 @@ def on_finalize(self, state: "State", **kwargs) -> bool: quant_args, ) update_offload_parameter(module, "weight", weight) - update_offload_parameter(module, "scale", scale) - update_offload_parameter(module, "zero_point", zero_point) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) return True @@ -317,8 +317,8 @@ def compress_module( update_offload_parameter(module, "weight") scale, zero_point = quant_args.get_observer()(module.weight) - update_offload_parameter(module, "scale", scale) - update_offload_parameter(module, "zero_point", zero_point) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) return loss From 3480d6b75c0df8c2de4d8c687a18d1cc8ee262d3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:50:59 +0000 Subject: [PATCH 54/59] ba --- src/llmcompressor/modifiers/quantization/gptq/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 4a22c6f9b..c90574b2a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -238,7 +238,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: weight = module.weight_acc / self._num_batches delattr(module, "weight_acc") - if self.naive_update: + else: weight = module.weight delattr(module, "weight_original") @@ -306,7 +306,7 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=module.weight_original.data if self.naive_update else module.weight.data + weight_original=module.weight.data if self.naive_update else module.weight_original.data ) if self.naive_update: From 7c55fc596d14eba258b366fbf7032c7ba5a26bfd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 19:12:30 -0400 Subject: [PATCH 55/59] delete parameter --- .../modifiers/quantization/gptq/base.py | 6 +-- src/llmcompressor/utils/fsdp/helpers.py | 51 +++++++++++++++---- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c90574b2a..92788dce1 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -27,7 +27,7 @@ from llmcompressor.transformers.finetune.data.data_helpers import ( create_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter, update_offload_parameter +from llmcompressor.utils.fsdp.helpers import delete_offload_parameter, register_offload_parameter, update_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -236,11 +236,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.naive_update: weight = module.weight_acc / self._num_batches - delattr(module, "weight_acc") + delete_offload_parameter(module, "weight_acc") else: weight = module.weight - delattr(module, "weight_original") + delete_offload_parameter(module, "weight_original") scale, zero_point = quant_args.get_observer()(weight) weight = fake_quantize( diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 5b60b68e2..e58b4f1c3 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -27,7 +27,7 @@ try: from accelerate.hooks import AlignDevicesHook - from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device _has_accelerate = True except ImportError: _has_accelerate = False @@ -339,16 +339,20 @@ def update_parameter_data( @contextlib.contextmanager def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ - Move a module's parameters to the execution device + Moves a module's parameters to the specified execution device. - :param module: module with parameters to align - :param execution_device: if provided, overrides module execution device - within the context + Args: + module (torch.nn.Module): Module with parameters to align. + execution_device (Optional[torch.device]): If provided, overrides the + module's execution device within the context. + + Yields: + None: Yields control while the module's parameters are aligned to the execution device. """ if has_offloaded_params(module): if execution_device is not None: original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = original_device + module._hf_hook.execution_device = execution_device module._hf_hook.pre_forward(module) yield @@ -361,17 +365,26 @@ def align_module(module: torch.nn.Module, execution_device: Optional[torch.devic devices = {} for name, param in module.named_parameters(): devices[name] = param.device - setattr(module, name, param.to(execution_device)) + set_module_tensor_to_device( + module, + name, + execution_device, + ) yield - for name, param_device in module.named_parameters: - setattr(module, name, param.to(param_device)) + for name, param in module.named_parameters(): + set_module_tensor_to_device( + module, + name, + devices[name], + ) else: yield + @contextlib.contextmanager def modify_offload_module( module: torch.nn.Module, @@ -400,5 +413,21 @@ def register_offload_parameter( # upstream candidate? -def deregister_offload_parameter(): - raise NotImplementedError() \ No newline at end of file +def delete_offload_parameter(module: torch.nn.Module, name: str): + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + prefix = weights_map.prefix + if dataset is not None: + del dataset[f"{prefix}{name}"] + + elif isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + elif weights_map is not None: + raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}") \ No newline at end of file From 0a8004b00725451a43b40677ddb0fbbe3268e04c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 30 Oct 2024 15:01:01 -0400 Subject: [PATCH 56/59] sensible generations for small calibration size --- .../modifiers/quantization/gptq/base.py | 57 +++++++++++-------- .../quantization/gptq/utils/gptq_quantize.py | 33 +++-------- .../finetune/data/data_helpers.py | 6 +- 3 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 92788dce1..c1e535215 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -7,20 +7,12 @@ QuantizationScheme, freeze_module_quantization, ) -from compressed_tensors.utils import ( - is_module_offloaded, - update_parameter_data, - update_prefix_dict, -) from loguru import logger from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( - add_batch, - quantize_weight, -) +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -112,7 +104,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): sequential_update: bool = True # DEPRECIATED naive_update: bool = True - batch_size: int = 1 + batch_size: int = -1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -139,6 +131,16 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return True + + @field_validator("naive_update", mode="before") + def validate_naive_update(cls, value: bool) -> bool: + if not value: + raise ValueError( + "`naive_update=False` is not implemented yet, please use " + "`naive_update=True`" + ) + + return True def on_initialize_structure(self, state: State, **kwargs): """ @@ -206,29 +208,23 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + if self.batch_size <= 0: + self.batch_size = len(state.data.calib.dataset) self._num_batches = math.ceil(len(state.data.calib.dataset) / self.batch_size) self.register_hooks(state.model) self.calibration_forward(state.model, state.data.calib) - #state.model(**state.model.dummy_inputs) self.remove_hooks() + self.finish_compression() # freeze quantization state.model.apply(freeze_module_quantization) return True - - def on_finalize(self, state: "State", **kwargs) -> bool: - """ - disable the quantization observers used by the OBCQ algorithm - - :param state: session state storing input model and calibration data - """ - if self._quantization_modifier: - self._quantization_modifier.finalize(state, **kwargs) - - for module in state.model.modules(): + + def finish_compression(self, model: torch.nn.Module): + for module in model.modules(): with align_module(module): quant_args = getattr_chain(module, "quantization_scheme.weights", None) if quant_args is None: @@ -253,6 +249,15 @@ def on_finalize(self, state: "State", **kwargs) -> bool: update_offload_parameter(module, "weight_scale", scale) update_offload_parameter(module, "weight_zero_point", zero_point) + def on_finalize(self, state: "State", **kwargs) -> bool: + """ + disable the quantization observers used by the OBCQ algorithm + + :param state: session state storing input model and calibration data + """ + if self._quantization_modifier: + self._quantization_modifier.finalize(state, **kwargs) + return True def calibration_forward( @@ -265,7 +270,7 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - dataloader = create_batch_dataloader(dataloader.dataset, batch_size=self.batch_size) + dataloader = create_batch_dataloader(dataloader, batch_size=self.batch_size) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) @@ -297,6 +302,7 @@ def compress_module( # Assume that first argument is the input inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") + logger.info(f"Using {inp.size(0)} samples") with align_module(module): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( @@ -306,19 +312,20 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=module.weight.data if self.naive_update else module.weight_original.data + weight_original=None if self.naive_update else module.weight_original.data ) if self.naive_update: module.weight_acc += quantized_weight update_offload_parameter(module, "weight_acc") else: - module.weight += (quantized_weight - module.weight) * self._num_batches + module.weight += (quantized_weight - module.weight) / self._num_batches update_offload_parameter(module, "weight") scale, zero_point = quant_args.get_observer()(module.weight) update_offload_parameter(module, "weight_scale", scale) update_offload_parameter(module, "weight_zero_point", zero_point) + update_offload_parameter(module, "weight_g_idx", g_idx) # NOT SURE IF THIS IS CORRECT return loss diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 617572391..a2354a242 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -17,29 +17,7 @@ GPTQ_PRECISION = torch.float32 -def add_batch(H: torch.Tensor, nsamples: int , module: torch.nn.Module, inp: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(module, torch.nn.Linear) or isinstance( - module, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - H *= nsamples / (nsamples + tmp) - nsamples += tmp - inp = inp.to(dtype=H.dtype) - inp = math.sqrt(2 / nsamples) * inp - H += inp.matmul(inp.t()) - - return H, nsamples - - -def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: +def compute_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs @@ -129,7 +107,8 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() - H = compute_hessian(inp, module_class, device=weight.device) + if weight_original is not None: + raise NotImplementedError() # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -140,6 +119,8 @@ def quantize_weight( num_rows = W.shape[0] num_columns = W.shape[1] + H = compute_hessian(inp, module_class, device=weight.device) + if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( @@ -202,9 +183,9 @@ def quantize_weight( W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): - w = weight_original[:, i] + w = W1[:, i] d = Hinv1[i, i] - q = W1[:, i].clone() + q = w.clone() # quantize column if strategy == QuantizationStrategy.TENSOR: diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 6f336aa8b..92d73edc2 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -23,7 +23,7 @@ def create_batch_dataloader( - dataset: datasets.Dataset, + dataloader: torch.utils.data.DataLoader, batch_size: int, ) -> torch.utils.data.DataLoader: """ @@ -33,6 +33,8 @@ def create_batch_dataloader( :param batch_size: batch size of new dataloader :return: dataloader """ + dataset = dataloader.dataset + sampler = dataloader.sampler.__class__(dataset) def pad_sequences(batch): # extract input_ids and attention_mask from the batch @@ -52,7 +54,7 @@ def pad_sequences(batch): return torch.utils.data.DataLoader( dataset, batch_size=batch_size, - shuffle=True, + sampler=sampler, collate_fn=pad_sequences, pin_memory=True, ) From d234b322df80dc565253384473a18b1c000c7f14 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 30 Oct 2024 15:13:04 -0400 Subject: [PATCH 57/59] remove unnecessary variables --- src/llmcompressor/modifiers/utils/hooks.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index dce60a6a5..44e2cd8a7 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -102,11 +102,6 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 - _module_inputs: Dict[torch.nn.Module, List[Tuple[Tuple[Any, ...], Dict[str, Any]]]] = defaultdict(lambda: []) - _module_outputs: Dict[torch.nn.Module, Union[List[Tuple[Any, ...]], torch.Tensor]] = defaultdict(lambda: []) - - _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] - _layer_outputs: List[Tuple[Any, ...]] = [] @abstractmethod def pre_compress_module( From eeb5c8316400540da4a0010f4ca0658bb0d02c62 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 30 Oct 2024 15:25:05 -0400 Subject: [PATCH 58/59] remove non-naive updating stuff to focus on naive updating --- .../modifiers/quantization/gptq/base.py | 46 +++---------------- src/llmcompressor/modifiers/utils/hooks.py | 3 -- 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c1e535215..471340a26 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -72,7 +72,6 @@ class GPTQModifier(Modifier, LayerCompressorMixin): :param sequential_update: Whether or not to update weights sequentially by layer. This option is depreciated and setting to False is no longer supported - :param naive_update: TODO :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass @@ -103,7 +102,6 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - naive_update: bool = True batch_size: int = -1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -131,16 +129,6 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return True - - @field_validator("naive_update", mode="before") - def validate_naive_update(cls, value: bool) -> bool: - if not value: - raise ValueError( - "`naive_update=False` is not implemented yet, please use " - "`naive_update=True`" - ) - - return True def on_initialize_structure(self, state: State, **kwargs): """ @@ -216,7 +204,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.calibration_forward(state.model, state.data.calib) self.remove_hooks() - self.finish_compression() + self.finish_compression(state.model) # freeze quantization state.model.apply(freeze_module_quantization) @@ -230,13 +218,8 @@ def finish_compression(self, model: torch.nn.Module): if quant_args is None: continue - if self.naive_update: - weight = module.weight_acc / self._num_batches - delete_offload_parameter(module, "weight_acc") - - else: - weight = module.weight - delete_offload_parameter(module, "weight_original") + weight = module.weight_acc / self._num_batches + delete_offload_parameter(module, "weight_acc") scale, zero_point = quant_args.get_observer()(weight) weight = fake_quantize( @@ -275,12 +258,7 @@ def calibration_forward( run_calibration_forward(model, dataloader, mask_padding=True) def pre_compress_module(self, module: torch.nn.Module): - # TODO: better names? - if self.naive_update: - register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) - - else: - register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) + register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) def compress_module( self, @@ -305,27 +283,17 @@ def compress_module( logger.info(f"Using {inp.size(0)} samples") with align_module(module): - loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + loss, quantized_weight, _scale, _zero_point, _g_idx = quantize_weight( module.weight.data, inp, quant_args, blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=None if self.naive_update else module.weight_original.data ) - if self.naive_update: - module.weight_acc += quantized_weight - update_offload_parameter(module, "weight_acc") - else: - module.weight += (quantized_weight - module.weight) / self._num_batches - update_offload_parameter(module, "weight") - - scale, zero_point = quant_args.get_observer()(module.weight) - update_offload_parameter(module, "weight_scale", scale) - update_offload_parameter(module, "weight_zero_point", zero_point) - update_offload_parameter(module, "weight_g_idx", g_idx) # NOT SURE IF THIS IS CORRECT + module.weight_acc += quantized_weight + update_offload_parameter(module, "weight_acc") return loss diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 44e2cd8a7..9dcbac9d3 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -9,9 +9,6 @@ from torch.utils.hooks import RemovableHandle from collections import defaultdict -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch -from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException -from llmcompressor.utils.fsdp.helpers import register_offload_parameter from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params From c7c8d04aad57a0f74672440e3245ddfc9264d790 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 1 Nov 2024 13:38:56 -0400 Subject: [PATCH 59/59] use observer to calculate qparams --- .../quantization/gptq/utils/gptq_quantize.py | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a2354a242..a625e8a7b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -2,6 +2,7 @@ from copy import copy from typing import Tuple, Union, Optional, Type +from llmcompressor.observers.base import Observer import torch import transformers from compressed_tensors.quantization import ( @@ -59,27 +60,6 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H -def compute_scale_zero_point( - W: torch.Tensor, - quant_args: QuantizationArgs, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Compute the scale and zero point of a module weight - TODO: revisit after observers refactor - - :param W: module weight - :param quant_args: quantization arguments which determine how quantization - parameters are calculated - :return: scale and zero_point - """ - # TODO: revisit after observers refactor - - scale, zero_point = quant_args.get_observer()(W, g_idx=None) - scale = scale.to(dtype=W.dtype) - zero_point = zero_point.to(dtype=quant_args.pytorch_dtype()) - return scale, zero_point - - def quantize_weight( weight: torch.Tensor, inp: torch.Tensor, @@ -107,6 +87,13 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() + # create observer for calculating quantization parameters + observer = Observer.load_from_registry( + "minmax", + quantization_args=quant_args, + averaging_constant=1.0, # ignore moving average + ) + if weight_original is not None: raise NotImplementedError() @@ -131,22 +118,22 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] else: - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) else: - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) # sparsity mask sparsity = tensor_sparsity(W)