Kylesayrs/gptq steps #879

Draft: kylesayrs wants to merge 62 commits into main from kylesayrs/gptq-steps
Commits (62)
98b284b  WIP  (kylesayrs, Oct 16, 2024)
e3a98cc  WIP: begin quantize_weight  (kylesayrs, Oct 16, 2024)
bc9b3bc  WIP  (kylesayrs, Oct 16, 2024)
b77c7bf  WIP  (kylesayrs, Oct 16, 2024)
7be5aed  wip  (kylesayrs, Oct 16, 2024)
e01094f  compilable  (kylesayrs, Oct 16, 2024)
ad9f5a8  compilable  (kylesayrs, Oct 16, 2024)
e4ee0af  wip  (kylesayrs, Oct 16, 2024)
d9ba539  add example  (kylesayrs, Oct 16, 2024)
83a5762  wip  (kylesayrs, Oct 16, 2024)
7f49ab4  runnable  (kylesayrs, Oct 16, 2024)
ac0d926  batching  (kylesayrs, Oct 21, 2024)
6304973  calibration forward context  (kylesayrs, Oct 21, 2024)
868a480  fix stuff  (kylesayrs, Oct 21, 2024)
86c8a06  wip  (kylesayrs, Oct 21, 2024)
1305173  use hooks list  (kylesayrs, Oct 21, 2024)
e6adc5a  layer compressor  (kylesayrs, Oct 22, 2024)
f65f832  style  (kylesayrs, Oct 22, 2024)
1e22569  use layer compressor  (kylesayrs, Oct 22, 2024)
9324695  replicate dtypes  (kylesayrs, Oct 22, 2024)
eef4fb6  write weight changes  (kylesayrs, Oct 22, 2024)
485813a  revert example  (kylesayrs, Oct 22, 2024)
6006155  organization  (kylesayrs, Oct 22, 2024)
c10d2ee  add create_single_batch_dataloader  (kylesayrs, Oct 22, 2024)
6371193  add back empty_cache until I can justify removing it  (kylesayrs, Oct 22, 2024)
92315a5  better type hinting, faster mask applying  (kylesayrs, Oct 22, 2024)
8903fbf  Merge remote-tracking branch 'origin' into kylesayrs/gptq-hooks  (kylesayrs, Oct 22, 2024)
8a25c68  remove breakpoint  (kylesayrs, Oct 22, 2024)
6cd0d6c  apply style, add true_sequential docstring  (kylesayrs, Oct 22, 2024)
0e0c586  update docstring  (kylesayrs, Oct 22, 2024)
d23aabb  use private attrs  (kylesayrs, Oct 22, 2024)
355074b  more docstring  (kylesayrs, Oct 23, 2024)
bf2184d  docstrings  (kylesayrs, Oct 23, 2024)
0b418c7  docstrings  (kylesayrs, Oct 23, 2024)
56cceea  docstrings  (kylesayrs, Oct 23, 2024)
7c7e3bc  move hooksmixin to separate file  (kylesayrs, Oct 23, 2024)
2d52183  docstrings  (kylesayrs, Oct 23, 2024)
d6ff46a  Merge branch 'main' into kylesayrs/gptq-hooks  (kylesayrs, Oct 23, 2024)
9081f12  fix docstring, better arguments grouping  (kylesayrs, Oct 23, 2024)
96e9496  use LayerCompressorMixin  (kylesayrs, Oct 24, 2024)
7fbf8b1  docstrings  (kylesayrs, Oct 24, 2024)
3d3af2a  add back hessian hook to support bs1  (kylesayrs, Oct 24, 2024)
b3021ab  wip  (kylesayrs, Oct 25, 2024)
8508b63  accumulate  (kylesayrs, Oct 25, 2024)
3ff271d  virtualize batches for layers  (kylesayrs, Oct 25, 2024)
d6c6dc3  maybe works, but padding is wrong  (kylesayrs, Oct 25, 2024)
400fa08  WIP  (kylesayrs, Oct 29, 2024)
03515f0  remove hessian  (kylesayrs, Oct 29, 2024)
6e37f64  allocated original weight  (kylesayrs, Oct 29, 2024)
09dae14  proper clone  (kylesayrs, Oct 29, 2024)
944601e  remove breakpoint  (kylesayrs, Oct 29, 2024)
adbcee8  naive_update option  (kylesayrs, Oct 29, 2024)
f4acab2  remove true sequential  (kylesayrs, Oct 29, 2024)
151f566  allow update_offload_parameter to not require data  (kylesayrs, Oct 29, 2024)
76ebc86  bugfix  (kylesayrs, Oct 29, 2024)
3480d6b  ba  (kylesayrs, Oct 29, 2024)
7c55fc5  delete parameter  (kylesayrs, Oct 29, 2024)
0a8004b  sensible generations for small calibration size  (kylesayrs, Oct 30, 2024)
d234b32  remove unnecessary variables  (kylesayrs, Oct 30, 2024)
eeb5c83  remove non-naive updating stuff to focus on naive updating  (kylesayrs, Oct 30, 2024)
99a2d97  Merge remote-tracking branch 'origin' into kylesayrs/gptq-steps  (kylesayrs, Nov 1, 2024)
c7c8d04  use observer to calculate qparams  (kylesayrs, Nov 1, 2024)
260 changes: 112 additions & 148 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,36 +1,40 @@
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import math
from compressed_tensors.quantization import (
QuantizationScheme,
disable_quantization,
enable_quantization,
)
from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization
from loguru import logger
from pydantic import Field, field_validator
from torch.nn import Module
from pydantic import Field, PrivateAttr, field_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier, ModifierFactory
from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization
from llmcompressor.modifiers.quantization.gptq.utils import (
GPTQWrapper,
get_output_error,
)
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight
from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier
from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.context import fix_fsdp_module_name
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
qat_active,
from llmcompressor.transformers.finetune.data.data_helpers import (
create_batch_dataloader,
)
from llmcompressor.utils.fsdp.helpers import delete_offload_parameter, register_offload_parameter, update_offload_parameter
from llmcompressor.utils.helpers import (
align_module,
calibration_forward_context,
getattr_chain,
)
from compressed_tensors.quantization import (
fake_quantize,
)

from llmcompressor.utils.pytorch.module import qat_active

__all__ = ["GPTQModifier"]


class GPTQModifier(Modifier):
class GPTQModifier(Modifier, LayerCompressorMixin):
"""
Modifier for applying the one-shot OBCQ algorithm to a model

@@ -48,6 +52,7 @@ class GPTQModifier(Modifier):
| test_stage:
| obcq_modifiers:
| GPTQModifier:
| true_sequential: False
| dampening_frac: 0.001
| block_size: 128
| config_groups:
@@ -67,8 +72,8 @@

:param sequential_update: Whether or not to update weights sequentially by layer.
This option is deprecated and setting it to False is no longer supported
:param targets: list of layer names to compress during GPTQ, or '__ALL__'
to compress every layer in the model
:param sequential_targets: list of layer names to compress during GPTQ, or
'__ALL__' to compress every layer in the model
:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Set to True to quantize using an existing quantization modifier,
or pass in the configuration for a quantization modifier if one does not
@@ -97,21 +102,22 @@ class GPTQModifier(Modifier):
"""

sequential_update: bool = True  # DEPRECATED
targets: Union[str, List[str], None] = None
batch_size: int = -1
sequential_targets: Union[str, List[str], None] = None
block_size: int = 128
quantize: Union[bool, Dict] = True
dampening_frac: Optional[float] = 0.01
quantize: Union[bool, Dict] = True

# arguments used for quant modifier
config_groups: Optional[Dict[str, QuantizationScheme]] = None
scheme: Optional[Union[str, Dict[str, Any]]] = None
targets: Union[str, List[str], None] = None
ignore: List[str] = Field(default_factory=list)
disable_quantization_observer_epoch: Optional[float] = None
num_calibration_steps: Optional[int] = None
scheme: Optional[Union[str, Dict[str, Any]]] = None
disable_quantization_observer_epoch: Optional[float] = None

model: Optional[Any] = None
layer_compressors_: Optional[List[Any]] = None
compressible_layers_: Optional[List] = None
quantization_modifier_: Any = None
_quantization_modifier: Optional[QuantizationModifier] = PrivateAttr()
_num_batches: int = PrivateAttr()

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
@@ -174,8 +180,8 @@ def on_initialize_structure(self, state: State, **kwargs):
self._build_quant_modifier_from_dict(self.quantize)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)
if self._quantization_modifier:
self._quantization_modifier.on_initialize_structure(state, **kwargs)

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
@@ -185,143 +191,118 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
"""
if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
if self._quantization_modifier:
self._quantization_modifier.initialize(state, **kwargs)
if not self.quantize:
raise ValueError("To use the GPTQModifier, quantization must be enabled.")

if self.batch_size <= 0:
self.batch_size = len(state.data.calib.dataset)
self._num_batches = math.ceil(len(state.data.calib.dataset) / self.batch_size)

modifiable_model = state.model
calibration_dataloader = state.data.calib
self.register_hooks(state.model)
self.calibration_forward(state.model, state.data.calib)

if self.sequential_targets is None:
# if no targets are provided, default to the modules that shouldn't be
# split by FSDP. For Transformers models this is equivalent to the
# decoder layers (ie LlamaDecoderLayer)
self.sequential_targets = get_no_split_params(modifiable_model)
self.remove_hooks()
self.finish_compression(state.model)

self.initialize_compression(modifiable_model, calibration_dataloader)
self.apply_compression(calibration_dataloader)
# freeze quantization
state.model.apply(freeze_module_quantization)

return True

def finish_compression(self, model: torch.nn.Module):
for module in model.modules():
with align_module(module):
quant_args = getattr_chain(module, "quantization_scheme.weights", None)
if quant_args is None:
continue

weight = module.weight_acc / self._num_batches
delete_offload_parameter(module, "weight_acc")

scale, zero_point = quant_args.get_observer()(weight)
weight = fake_quantize(
weight,
scale,
zero_point,
quant_args,
)
update_offload_parameter(module, "weight", weight)
update_offload_parameter(module, "weight_scale", scale)
update_offload_parameter(module, "weight_zero_point", zero_point)

def on_finalize(self, state: "State", **kwargs) -> bool:
"""
disable the quantization observers used by the OBCQ algorithm

:param state: session state storing input model and calibration data
"""
if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)
if self._quantization_modifier:
self._quantization_modifier.finalize(state, **kwargs)

return True

def compressible_layers(self) -> Dict:
def calibration_forward(
self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader
):
"""
Retrieves the modules corresponding to a list of
compressible layer names
Perform calibration forward pass with one batch whose size is the size
of the dataset

:precondition: self.model is set and is a torch.nn.Module
:return: dictionary of modules to compress
:param model: model to perform forward pass with
:param dataloader: dataloader containing calibration dataset
"""
if not isinstance(self.model, Module):
raise ValueError(
"`self.model` must be a torch.nn.Module to use "
f"the {self.__class__.__qualname__} modifier but got "
f"{type(self.model)} instead"
)
dataloader = create_batch_dataloader(dataloader, batch_size=self.batch_size)
with calibration_forward_context(model):
run_calibration_forward(model, dataloader, mask_padding=True)

return get_layers(self.sequential_targets, self.model)
def pre_compress_module(self, module: torch.nn.Module):
register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False))

def initialize_compression(
def compress_module(
self,
model: Module,
dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None,
):
name: str,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
) -> float:
"""
Setup for GPTQ, initializes the model
and other parameters, also initializes the
compressible layers of model, and sets the device
Quantize a module's weight according to the GPTQ algorithm

:param model: model to initialize for compression
:param dataloader: calibration data, not used by GPTQ in this function
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.layer_compressors_ = []

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
name = fix_fsdp_module_name(name)
logger.info(f"Preparing {name} for compression")
args = self._pruning_arguments()
comp_cls = self._compression_class()
compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
self.layer_compressors_.append(compressor)

# for the initial forward data pass, add an early stop exception in order
# to capture inputs right before being compressed by first module
first_layer_compressor = self.layer_compressors_[0]
first_layer_compressor.set_early_stop()

@torch.no_grad()
def apply_compression(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
"""
Run GPTQ on the loaded model, using dataloader as calibration data
:param name: name of module being quantized
:param module: module being quantized
:param args: input arguments for module forward pass

:param dataloader: calibration data for GPTQ
:return: total loss from applying weight quantization to this module
"""
class_name = self.__class__.__name__.replace("PyTorch", "")
logger.info(
f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
)

# quantization scales and zp are already initialized but we do not
# want to calibrate wrt to these
self.model.apply(disable_quantization)

forward_pass_use_cache = self.model.config.use_cache
self.model.config.use_cache = False

# run_calibration_forward uses the early stop exception to capture values
# as intermediates right before the forward pass of the first module
intermediates = run_calibration_forward(
self.model, dataloader, mask_padding=True
)
self.layer_compressors_[0].clear_early_stop()

num_layers = len(self.compressible_layers_)
for idx, layer_compressor in enumerate(self.layer_compressors_):
logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====")

# run the forward pass for each transformer layer (block) one at a time
logger.info(f"Calibrating {layer_compressor.name}...")
layer_compressor.pre_compress()
unquantized_outputs = layer_compressor.calibrate_layer(intermediates)

layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()

# perform a second forward pass of the module to calculate weight-quantized
# outputs for use as inputs to the next layer (block)
quantized_outputs = layer_compressor.calibrate_layer(intermediates)
error = get_output_error(unquantized_outputs, quantized_outputs)
logger.info(f"Mean output error from quantization: {error:.3f}")
intermediates = quantized_outputs
logger.info(f"Quantizing {name}...")

# Assume that first argument is the input
inp = args[0]
quant_args = getattr_chain(module, "quantization_scheme.weights")
logger.info(f"Using {inp.size(0)} samples")

with align_module(module):
loss, quantized_weight, _scale, _zero_point, _g_idx = quantize_weight(
module.weight.data,
inp,
quant_args,
blocksize=self.block_size,
percdamp=self.dampening_frac,
module_class=type(module),
)

self.model.config.use_cache = forward_pass_use_cache
module.weight_acc += quantized_weight
update_offload_parameter(module, "weight_acc")

# re-enable quantization
self.model.apply(enable_quantization)
return loss

def _build_quant_modifier(self):
"""
Build a quantization modifier based on the specified config_groups,
ignore list, and num_calibration_steps.

:postcondition: self.quantization_modifier_ is set to the built
:postcondition: self._quantization_modifier is set to the built
quantization modifier
"""

@@ -347,26 +328,9 @@ def _build_quant_modifier(self):
def _build_quant_modifier_from_dict(self, quant_config):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
self._quantization_modifier = ModifierFactory.create(
modifier_type,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)

def _pruning_arguments(self):
"""
Gather the parameters needed for root module compression in a dict

:return: dict of params for pruning
"""
return {
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
}

def _compression_class(self):
"""
:return: wrapper class used for root modules of this compression class
"""
return GPTQWrapper
@@ -1,4 +1,3 @@
# flake8: noqa

from .gptq_wrapper import *
from .helpers import *
from .gptq_quantize import *
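
The diff above replaces the LayerCompressor-driven flow with a hook-based one: compress_module runs a GPTQ solve per calibration batch and accumulates the result into a per-module weight_acc parameter, and finish_compression then averages that accumulation over the number of batches, derives scale and zero point from the weight observer, and fake-quantizes the averaged weight. The standalone sketch below illustrates that accumulate-then-quantize flow in plain PyTorch; gptq_solve and symmetric_observer here are simplified stand-ins for the PR's quantize_weight and observer machinery, not the actual implementations.

import torch

def symmetric_observer(weight: torch.Tensor, num_bits: int = 4):
    # Stand-in observer: per-output-channel symmetric scale with a zero zero-point
    qmax = 2 ** (num_bits - 1) - 1
    scale = weight.abs().amax(dim=1, keepdim=True) / qmax
    return scale, torch.zeros_like(scale)

def fake_quantize(weight, scale, zero_point, num_bits: int = 4):
    # Quantize then dequantize (simplified stand-in for the fake_quantize used in the diff)
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(weight / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

def gptq_solve(weight, batch_inputs):
    # Placeholder for quantize_weight(): returns the weight as updated for one batch
    return weight + 1e-3 * torch.randn_like(weight)

weight = torch.randn(64, 128)           # module.weight
weight_acc = torch.zeros_like(weight)   # registered per module in pre_compress_module
num_batches = 4

# compress_module(): one GPTQ solve per calibration batch, accumulated
for _ in range(num_batches):
    batch_inputs = torch.randn(8, 128)
    weight_acc += gptq_solve(weight, batch_inputs)

# finish_compression(): average, observe, fake-quantize, write back to the module
averaged = weight_acc / num_batches
scale, zero_point = symmetric_observer(averaged)
final_weight = fake_quantize(averaged, scale, zero_point)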
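For context on how the modifier is driven end to end, the snippet below follows the one-shot pattern used in llm-compressor's example scripts; it is a sketch, not part of this diff. The batch_size argument mirrors the field added in this PR (-1 means a single batch over the whole calibration set, matching the previous behavior) and may change before merge; the model and dataset names are placeholders.

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# 4-bit weight / 16-bit activation quantization of Linear layers, skipping the LM head
recipe = GPTQModifier(
    scheme="W4A16",
    targets="Linear",
    ignore=["lm_head"],
    batch_size=-1,  # field added in this PR; -1 -> one batch over the full calibration set
)

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder causal LM
    dataset="open_platypus",                     # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)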