[Observer Restructure]: Add Observers; Add calibration and frozen steps to `QuantizationModifier` (#837)

* update function
* wip
* clean-up; fix imports
* clean-up
* more clean-up
* bug fix
* update for kvcache
* get kv_cache to work
* docstring
* fix comment
* fix condition for dynamic
* update
* update tests
* add observer tests
* add flake8 skip
* apply updated mse fixes
* fix import
* Update src/llmcompressor/modifiers/quantization/calibration.py
  Co-authored-by: Kyle Sayers <[email protected]>
* Update src/llmcompressor/modifiers/quantization/calibration.py
  Co-authored-by: Kyle Sayers <[email protected]>
* PR comments
* clean-up
* move hook check to observer call
* update
* separate out calibration step

---------

Co-authored-by: Kyle Sayers <[email protected]>
Showing 17 changed files with 1,656 additions and 6 deletions.
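For context before the diffs: the calibration flow this commit wires up is driven by a recipe containing a `kv_cache_scheme`, like the one documented in the new cache module below. A minimal hedged sketch of applying such a recipe, assuming llmcompressor's `oneshot` entrypoint (the import path varies across versions) and placeholder model/dataset names that are not part of this commit:

```python
# Hedged sketch: run calibration so the QuantizationModifier's observers can
# compute k/v scales for the FP8 KV cache scheme, then freeze the parameters.
# Model and dataset names below are illustrative placeholders.
from llmcompressor.transformers import oneshot

recipe = """
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      kv_cache_scheme:
        num_bits: 8
        type: float
        strategy: tensor
        dynamic: false
        symmetric: true
"""

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model id
    dataset="open_platypus",                     # placeholder calibration set
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```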
The first file shown (a package `__init__.py`, judging by its star imports) adds the new cache module to the exports:

```diff
@@ -1,4 +1,5 @@
 # flake8: noqa

+from .cache import *
 from .gptq import *
 from .quantization import *
```
The second file shown is new (202 added lines) and implements the quantized KV cache:

````python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, List, Optional, Tuple

from compressed_tensors.quantization.lifecycle import KVCacheScaleType
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import Tensor
from transformers import DynamicCache as HFDynamicCache

from llmcompressor.observers import Observer


class QuantizedKVParameterCache(HFDynamicCache):
""" | ||
Quantized KV cache used in the forward call based on HF's dynamic cache. | ||
Quantization strategy (tensor, group, channel) set from Quantization arg's strategy | ||
Singleton, so that the same cache gets reused in all forward call of self_attn. | ||
Each time forward is called, .update() is called, and ._quantize(), ._dequantize() | ||
gets called appropriately. | ||
The size of tensor is | ||
`[batch_size, num_heads, seq_len - residual_length, head_dim]`. | ||
Triggered by adding kv_cache_scheme in the recipe. | ||
Example: | ||
```python3 | ||
recipe = ''' | ||
quant_stage: | ||
quant_modifiers: | ||
QuantizationModifier: | ||
kv_cache_scheme: | ||
num_bits: 8 | ||
type: float | ||
strategy: tensor | ||
dynamic: false | ||
symmetric: true | ||
''' | ||
""" | ||

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        """Singleton"""
        if cls._instance is None:
            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, quantization_args: QuantizationArgs):
        if not self._initialized:
            super().__init__()

            self.quantization_args = quantization_args

            self.k_observers: List[Observer] = []
            self.v_observers: List[Observer] = []

            # each index corresponds to the layer_idx of the attention layer
            self.k_scales: List[Tensor] = []
            self.v_scales: List[Tensor] = []

            self.k_zps: List[Tensor] = []
            self.v_zps: List[Tensor] = []

            self._initialized = True

    def update(
        self,
        key_states: Tensor,
        value_states: Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the k_scale and v_scale and output the
        fake-quantized key_states and value_states
        """

        if len(self.k_observers) <= layer_idx:
            k_observer_name = self.quantization_args.get_observer()
            k_observer = Observer.load_from_registry(
                k_observer_name, quantization_args=self.quantization_args
            )
            v_observer_name = self.quantization_args.get_observer()
            v_observer = Observer.load_from_registry(
                v_observer_name, quantization_args=self.quantization_args
            )

            self.k_observers.append(k_observer)
            self.v_observers.append(v_observer)

        q_key_states = self._quantize(
            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
        )
        q_value_states = self._quantize(
            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
        )

        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
        qdq_value_states = self._dequantize(
            q_value_states, KVCacheScaleType.VALUE, layer_idx
        )

        keys_to_return, values_to_return = qdq_key_states, qdq_value_states

        return keys_to_return, values_to_return

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states.
        A layer index can be optionally passed.
        """
        if len(self.key_cache) <= layer_idx:
            return 0
        # Since we cannot get the seq_length of each layer directly and instead
        # rely on `_seen_tokens`, which is updated every time "layer_idx" == 0,
        # this is a hack to get the actual seq_length for the given layer_idx;
        # this part of the code otherwise fails when used to
        # verify attn_weight shape in some models.
        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

    def reset_states(self):
        """reset the kv states (used in calibration)"""
        self.key_cache: List[Tensor] = []
        self.value_cache: List[Tensor] = []
        # used in `generate` to keep tally of how many tokens the cache has seen
        self._seen_tokens = 0
        self._quantized_key_cache: List[Tensor] = []
        self._quantized_value_cache: List[Tensor] = []

    def reset(self):
        """
        Reset the instantiation; a new instance is created on the next init.
        """
        QuantizedKVParameterCache._instance = None
        QuantizedKVParameterCache._initialized = False

    def _quantize(self, tensor, kv_type, layer_idx):
        """Quantizes a key/value using a defined quantization method."""
        from compressed_tensors.quantization.lifecycle.forward import quantize

        if kv_type == KVCacheScaleType.KEY:  # key type
            observer = self.k_observers[layer_idx]
            scales = self.k_scales
            zps = self.k_zps
        else:
            assert kv_type == KVCacheScaleType.VALUE
            observer = self.v_observers[layer_idx]
            scales = self.v_scales
            zps = self.v_zps

        scale, zp = observer(tensor)
        if len(scales) <= layer_idx:
            scales.append(scale)
            zps.append(zp)
        else:
            scales[layer_idx] = scale
            zps[layer_idx] = zp

        q_tensor = quantize(
            x=tensor,
            scale=scale,
            zero_point=zp,
            args=self.quantization_args,
        )
        return q_tensor

    def _dequantize(self, qtensor, kv_type, layer_idx):
        """Dequantizes back the tensor that was quantized by `self._quantize()`."""
        from compressed_tensors.quantization.lifecycle.forward import dequantize

        if kv_type == KVCacheScaleType.KEY:
            scale = self.k_scales[layer_idx]
            zp = self.k_zps[layer_idx]
        else:
            assert kv_type == KVCacheScaleType.VALUE
            scale = self.v_scales[layer_idx]
            zp = self.v_zps[layer_idx]

        qdq_tensor = dequantize(
            x_q=qtensor,
            scale=scale,
            zero_point=zp,
            args=self.quantization_args,
        )
        return qdq_tensor
````
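As a usage illustration (a hedged sketch, not from this commit): the cache is a process-wide singleton, observers are created lazily per layer, and `update()` returns fake-quantized key/value states. The `QuantizationArgs` field values and tensor shapes below are assumptions mirroring the docstring recipe:

```python
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

# Assumed field values, mirroring the kv_cache_scheme in the docstring above.
args = QuantizationArgs(num_bits=8, type="float", strategy="tensor", symmetric=True)

cache_a = QuantizedKVParameterCache(args)
cache_b = QuantizedKVParameterCache(args)
assert cache_a is cache_b  # __new__ hands back the shared instance

# [batch_size, num_heads, seq_len, head_dim] -- illustrative shape
keys = torch.randn(1, 8, 16, 64)
values = torch.randn(1, 8, 16, 64)

# Lazily builds the layer-0 observers, records k/v scales, and returns
# quantize->dequantize (fake-quantized) tensors of the same shape.
qdq_keys, qdq_values = cache_a.update(keys, values, layer_idx=0)

cache_a.reset()  # drop the singleton; the next constructor call reinitializes
```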