add observer tests
dsikka committed Oct 21, 2024
1 parent b739db8 commit ac00c9b
Showing 6 changed files with 317 additions and 14 deletions.
13 changes: 6 additions & 7 deletions tests/llmcompressor/modifiers/calibration/test_frozen.py
@@ -15,16 +15,15 @@
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Linear

from llmcompressor.modifiers.quantization.calibration import (
freeze_module_quantization,
initialize_observer,
)
freeze_module_quantization,
initialize_observer,
)


def test_set_module_for_calibration():
21 changes: 14 additions & 7 deletions tests/llmcompressor/modifiers/calibration/test_kv_cache.py
@@ -17,11 +17,17 @@
QuantizationConfig,
QuantizationStatus,
apply_quantization_config,
is_attention_module
is_attention_module,
)
from llmcompressor.modifiers.quantization.calibration import set_unset_kv_cache, freeze_module_quantization, calibrate_kv_cache_input_hook, calibrate_kv_cache_output_hook
from transformers import AutoModelForCausalLM

from llmcompressor.modifiers.quantization.calibration import (
calibrate_kv_cache_input_hook,
calibrate_kv_cache_output_hook,
freeze_module_quantization,
set_unset_kv_cache,
)

config = {
"quant_method": "compressed-tensors",
"format": "fakequant",
@@ -43,16 +49,17 @@
},
},
}
_hooks = []


def _prep_for_calibration(module: torch.nn.Module):
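    # attach the kv-cache calibration hooks to each attention module and put
    # it into calibration mode so its observers can record k/v quantization
    # parameters during the forward passes in the test below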
if is_attention_module(module):
pre_h = module.register_forward_pre_hook(calibrate_kv_cache_input_hook(), with_kwargs=True)
post_h = module.register_forward_hook(calibrate_kv_cache_output_hook())
_hooks.append(pre_h)
_hooks.append(post_h)
module.register_forward_pre_hook(
calibrate_kv_cache_input_hook(), with_kwargs=True
)
module.register_forward_hook(calibrate_kv_cache_output_hook())
module.quantization_status = QuantizationStatus.CALIBRATION


@pytest.mark.parametrize("config", [config])
def test_kv_cache_quantization(config):
sample = {
13 changes: 13 additions & 0 deletions tests/llmcompressor/observers/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
109 changes: 109 additions & 0 deletions tests/llmcompressor/observers/test_helpers.py
@@ -0,0 +1,109 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationStatus,
apply_quantization_config,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization.calibration import (
calibrate_input_hook,
initialize_observer,
)
from llmcompressor.observers.helpers import get_observer_token_count


def _prep_for_input_quant_calibration(module: torch.nn.Module):
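    # register the input-activation calibration hook on every module that
    # received a quantization scheme from apply_quantization_config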
quantization_scheme = getattr(module, "quantization_scheme", None)
if not quantization_scheme:
return

module.register_forward_pre_hook(calibrate_input_hook())
module.quantization_status = QuantizationStatus.CALIBRATION


def test_get_observer_token_count():
model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
model.eval()
config = QuantizationConfig(
format="fakequant",
quantization_status="calibration",
config_groups={
"group_1": {
"input_activations": {
"num_bits": 8,
"type": "int",
"symmetric": False,
"strategy": "tensor",
},
"targets": ["Linear"],
},
},
)
apply_quantization_config(model, config)
model.apply(lambda module: initialize_observer(module, base_name="input"))
model.apply(_prep_for_input_quant_calibration)

# start calibration
calib_list = [
"I am a string that",
"is used for calibration so",
"that your model is",
"quantized properly.",
]

total_num_tokens_observed = 0
for calib_sample in calib_list:
calib_tensor = tokenizer(calib_sample, return_tensors="pt")
_ = model(**calib_tensor)
total_num_tokens_observed += len(calib_tensor.input_ids.flatten())

counter = get_observer_token_count(model)

    # filter out the None values
    # (modules whose observers did not record any tokens)
    counter = {k: v for k, v in counter.items() if v is not None}

    # iterate over all the layers in the model where the token count
    # has been observed
for i in range(model.config.num_hidden_layers):
# fetch the tokens observed by the router
tokens_observed_by_router = counter.pop(
f"model.layers.{i}.block_sparse_moe.gate"
)
assert tokens_observed_by_router == total_num_tokens_observed

# fetch the sum of tokens observed by all the experts
sum_tokens_observed_by_experts = 0
keys_for_this_layer = [
k
for k in counter.keys()
if f"model.layers.{i}.block_sparse_moe.experts" in k
]
for key in keys_for_this_layer:
sum_tokens_observed_by_experts += counter.pop(key)

        # each Mixtral expert consists of 3 linear layers (w1, w2, w3),
        # so we need to multiply by 3
assert (
sum_tokens_observed_by_experts
== total_num_tokens_observed * model.config.num_experts_per_tok * 3
)

    # nothing should be left in the counter
assert len(counter) == 0
118 changes: 118 additions & 0 deletions tests/llmcompressor/observers/test_min_max.py
@@ -0,0 +1,118 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

from llmcompressor.observers import Observer


def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
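    # map each column to a quantization group, then shuffle to mimic the
    # permuted g_idx produced by activation ordering (as in GPTQ act_order)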
perm = torch.randperm(columns)
return torch.tensor([index // group_size for index in range(columns)])[perm]


@pytest.mark.parametrize(
"symmetric,expected_scale,expected_zero_point",
[
(True, 0.0078, 0),
(False, 0.0039, -128),
],
)
def test_min_max_observer(symmetric, expected_scale, expected_zero_point):
tensor = torch.tensor([1, 1, 1, 1, 1])
num_bits = 8
weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric)

observer = weights.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=weights)
scale, zero_point = observer(tensor)

assert round(scale.item(), 4) == expected_scale
assert round(zero_point.item(), 4) == expected_zero_point


def test_min_max_observer_symmetric_scale_range():
tensor = torch.rand(4, 4)
tensor *= 127

num_bits = 8
weights = QuantizationArgs(num_bits=num_bits, symmetric=True)

observer = weights.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=weights)
scale, zero_point = observer(tensor)

    # for symmetric int8, scale = 2 * max_abs / 255; the largest possible
    # value is 2 * 128 / 255 ≈ 1.0039
    assert round(scale.item(), 4) <= 1.0039
assert round(zero_point.item(), 4) == 0


def test_min_max_observer_value_update():
inp = torch.tensor([1, 1, 1, 1, 1])
inp_update_max = torch.tensor([127, 1, 1, 1, 1])
inp_update_min = torch.tensor([-128, 1, 1, 1, 1])

delta = 1e-6

    # update the running max once (step 2) and the running min once (step 4)
tensors = [
inp,
inp,
inp_update_max, # update max
inp,
inp_update_min, # update min
]

tensor = inp
num_bits = 8
weights = QuantizationArgs(num_bits=num_bits, symmetric=True)
observer = weights.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=weights)
curr_max = 1
curr_min = 1
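    # the observer tracks min/max with a moving average; assuming the default
    # averaging_constant of 0.01, 0.99 * 1 + 0.01 * 127 == 2.26 and
    # 0.99 * 1 + 0.01 * (-128) == -0.29, matching the assertions below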
for i, tensor in enumerate(tensors):
observer(tensor)
        curr_max = max(observer.max_val.get("default"), curr_max)
        curr_min = min(observer.min_val.get("default"), curr_min)

if i < 2:
assert curr_max == 1
assert curr_min == 1
elif i < 4:
assert abs(curr_max - 2.2600) < delta
assert curr_min == 1
else:
assert abs(curr_max - 2.2600) < delta
assert abs(curr_min - (-0.2900)) < delta


def test_g_idx():
group_size = 2
input_shape = (128, 512)
tensor = torch.rand(input_shape)
weights = QuantizationArgs(num_bits=8, group_size=group_size)
g_idx = make_dummy_g_idx(tensor.shape[1], group_size)

observer = weights.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=weights)
scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx)

observer.reset()
scale, zero_point = observer(tensor[:, torch.argsort(g_idx)])

assert scale_g_idx == pytest.approx(scale)
assert zero_point_g_idx == pytest.approx(zero_point)
57 changes: 57 additions & 0 deletions tests/llmcompressor/observers/test_mse.py
@@ -0,0 +1,57 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

from llmcompressor.observers import MovingAverageMSEObserver, Observer


@pytest.mark.parametrize(
"symmetric,expected_scale,expected_zero_point",
[
(True, 0.0078, 0),
(False, 0.0039, -128),
],
)
def test_mse_observer(symmetric, expected_scale, expected_zero_point):
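    # the MSE observer chooses clipping bounds that minimize the mean-squared
    # quantization error rather than using the raw tensor min/max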
tensor = torch.tensor([1, 1, 1, 1, 1])
num_bits = 8
weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse")

observer = weights.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=weights)
scale, zero_point = observer(tensor)

assert isinstance(observer, MovingAverageMSEObserver)
assert round(scale.item(), 4) == expected_scale
assert round(zero_point.item(), 4) == expected_zero_point


def test_mse_observer_symmetric_scale_range():
tensor = torch.rand(4, 4)
tensor *= 127

num_bits = 8
    weights = QuantizationArgs(
        num_bits=num_bits, symmetric=True, observer="mse"
    )

observer = weights.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=weights)
scale, zero_point = observer(tensor)

    # for symmetric int8, scale = 2 * max_abs / 255; the largest possible
    # value is 2 * 128 / 255 ≈ 1.0039
    assert round(scale.item(), 4) <= 1.0039
assert round(zero_point.item(), 4) == 0
