Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Observer Restructure]: Add Observers; Add calibration and frozen steps to QuantizationModifier #837

Merged
merged 30 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7d6c73c
update functioon
dsikka Oct 10, 2024
7dad592
wip
dsikka Oct 13, 2024
ece6451
clean-up; fix imports
dsikka Oct 14, 2024
dbda873
clean-up
dsikka Oct 14, 2024
d1a5756
more clean-up
dsikka Oct 14, 2024
15597c3
bug fix
dsikka Oct 15, 2024
acdb8da
update for kvcache
dsikka Oct 17, 2024
28c0167
get kv_cache to work
dsikka Oct 17, 2024
841780d
docstring
dsikka Oct 17, 2024
5e21639
fix comment
dsikka Oct 17, 2024
de28cf8
fix condition for dynamic
dsikka Oct 18, 2024
a3ddb6f
Merge branch 'main' into update-foward
dsikka Oct 18, 2024
b0de448
update
dsikka Oct 18, 2024
b739db8
update tests
dsikka Oct 21, 2024
ac00c9b
add observer tests
dsikka Oct 21, 2024
a5eafad
Merge branch 'main' into update-foward
dsikka Oct 21, 2024
a68694d
add flake8 skip
dsikka Oct 21, 2024
ab2d0a6
apply updated mse fixes
dsikka Oct 22, 2024
27284b8
fix import
dsikka Oct 22, 2024
25a0025
Update src/llmcompressor/modifiers/quantization/calibration.py
dsikka Oct 25, 2024
e574c2a
Update src/llmcompressor/modifiers/quantization/calibration.py
dsikka Oct 25, 2024
43771e7
Merge branch 'main' into update-foward
dsikka Oct 25, 2024
14b69fd
PR comments
dsikka Oct 25, 2024
b4621fa
clean-up
dsikka Oct 25, 2024
92db43e
move hook check to observer call
dsikka Oct 25, 2024
99a9376
Merge branch 'main' into update-foward
dsikka Oct 28, 2024
9fc10a9
update
dsikka Oct 30, 2024
c4686d4
Merge branch 'main' into update-foward
dsikka Oct 30, 2024
e591528
separate out calibration step
dsikka Oct 31, 2024
031ba38
Merge branch 'main' into update-foward
dsikka Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# flake8: noqa

from .cache import *
from .gptq import *
from .quantization import *
202 changes: 202 additions & 0 deletions src/llmcompressor/modifiers/quantization/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, List, Optional, Tuple

from compressed_tensors.quantization.lifecycle import KVCacheScaleType
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import Tensor
from transformers import DynamicCache as HFDyanmicCache
kylesayrs marked this conversation as resolved.
Show resolved Hide resolved

from llmcompressor.observers import Observer


class QuantizedKVParameterCache(HFDyanmicCache):
"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
Singleton, so that the same cache gets reused in all forward call of self_attn.
Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
gets called appropriately.
The size of tensor is
`[batch_size, num_heads, seq_len - residual_length, head_dim]`.


Triggered by adding kv_cache_scheme in the recipe.

Example:

```python3
recipe = '''
quant_stage:
quant_modifiers:
QuantizationModifier:
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
'''
dsikka marked this conversation as resolved.
Show resolved Hide resolved

"""

_instance = None
_initialized = False

def __new__(cls, *args, **kwargs):
"""Singleton"""
if cls._instance is None:
cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
return cls._instance

def __init__(self, quantization_args: QuantizationArgs):
if not self._initialized:
super().__init__()

self.quantization_args = quantization_args

self.k_observers: List[Observer] = []
self.v_observers: List[Observer] = []

# each index corresponds to layer_idx of the attention layer
self.k_scales: List[Tensor] = []
self.v_scales: List[Tensor] = []

self.k_zps: List[Tensor] = []
self.v_zps: List[Tensor] = []

self._initialized = True

def update(
self,
key_states: Tensor,
value_states: Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the k_scale and v_scale and output the
fakequant-ed key_states and value_states
"""

if len(self.k_observers) <= layer_idx:
k_observer_name = self.quantization_args.get_observer()
k_observer = Observer.load_from_registry(
k_observer_name, quantization_args=self.quantization_args
)
v_observer_name = self.quantization_args.get_observer()
v_observer = Observer.load_from_registry(
v_observer_name, quantization_args=self.quantization_args
)

self.k_observers.append(k_observer)
self.v_observers.append(v_observer)

q_key_states = self._quantize(
key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
)
q_value_states = self._quantize(
value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
)

qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
qdq_value_states = self._dequantize(
q_value_states, KVCacheScaleType.VALUE, layer_idx
)

keys_to_return, values_to_return = qdq_key_states, qdq_value_states

return keys_to_return, values_to_return

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states.
A layer index can be optionally passed.
"""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and
# rely on `_seen_tokens` which is updated every "layer_idx" == 0,
# this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to
# verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

def reset_states(self):
"""reset the kv states (used in calibration)"""
self.key_cache: List[Tensor] = []
self.value_cache: List[Tensor] = []
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0
self._quantized_key_cache: List[Tensor] = []
self._quantized_value_cache: List[Tensor] = []

def reset(self):
"""
Reset the instantiation, create new instance on init
"""
QuantizedKVParameterCache._instance = None
QuantizedKVParameterCache._initialized = False

def _quantize(self, tensor, kv_type, layer_idx):
"""Quantizes a key/value using a defined quantization method."""
from compressed_tensors.quantization.lifecycle.forward import quantize

if kv_type == KVCacheScaleType.KEY: # key type
observer = self.k_observers[layer_idx]
scales = self.k_scales
zps = self.k_zps
else:
assert kv_type == KVCacheScaleType.VALUE
observer = self.v_observers[layer_idx]
scales = self.v_scales
zps = self.v_zps

scale, zp = observer(tensor)
if len(scales) <= layer_idx:
scales.append(scale)
zps.append(zp)
else:
scales[layer_idx] = scale
zps[layer_idx] = scale

q_tensor = quantize(
x=tensor,
scale=scale,
zero_point=zp,
args=self.quantization_args,
)
return q_tensor

def _dequantize(self, qtensor, kv_type, layer_idx):
"""Dequantizes back the tensor that was quantized by `self._quantize()`"""
from compressed_tensors.quantization.lifecycle.forward import dequantize

if kv_type == KVCacheScaleType.KEY:
scale = self.k_scales[layer_idx]
zp = self.k_zps[layer_idx]
else:
assert kv_type == KVCacheScaleType.VALUE
scale = self.v_scales[layer_idx]
zp = self.v_zps[layer_idx]

qdq_tensor = dequantize(
x_q=qtensor,
scale=scale,
zero_point=zp,
args=self.quantization_args,
)
return qdq_tensor
Loading
Loading