-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
317 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
from compressed_tensors.quantization import ( | ||
QuantizationConfig, | ||
QuantizationStatus, | ||
apply_quantization_config, | ||
) | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor.modifiers.quantization.calibration import ( | ||
calibrate_input_hook, | ||
initialize_observer, | ||
) | ||
from llmcompressor.observers.helpers import get_observer_token_count | ||
|
||
|
||
def _prep_for_input_quant_calibration(module: torch.nn.Module): | ||
quantization_scheme = getattr(module, "quantization_scheme", None) | ||
if not quantization_scheme: | ||
return | ||
|
||
module.register_forward_pre_hook(calibrate_input_hook()) | ||
module.quantization_status = QuantizationStatus.CALIBRATION | ||
|
||
|
||
def test_get_observer_token_count():
    """
    Verify the per-module token counts from ``get_observer_token_count``.

    A small Mixtral-style MoE checkpoint is configured for 8-bit asymmetric
    per-tensor input-activation quantization on every ``Linear`` target,
    calibrated on a handful of short strings, and then checked layer by layer:
    the router (``gate``) must have observed every token, and the experts of
    a layer together must have observed
    ``tokens * num_experts_per_tok * 3`` tokens.
    """
    model_stub = "Isotonic/TinyMixtral-4x248M-MoE"
    model = AutoModelForCausalLM.from_pretrained(model_stub)
    tokenizer = AutoTokenizer.from_pretrained(model_stub)
    model.eval()

    input_activation_args = {
        "num_bits": 8,
        "type": "int",
        "symmetric": False,
        "strategy": "tensor",
    }
    config = QuantizationConfig(
        format="fakequant",
        quantization_status="calibration",
        config_groups={
            "group_1": {
                "input_activations": input_activation_args,
                "targets": ["Linear"],
            },
        },
    )
    apply_quantization_config(model, config)
    model.apply(lambda module: initialize_observer(module, base_name="input"))
    model.apply(_prep_for_input_quant_calibration)

    # run calibration, tallying how many tokens are fed through the model
    calibration_samples = [
        "I am a string that",
        "is used for calibration so",
        "that your model is",
        "quantized properly.",
    ]

    total_num_tokens_observed = 0
    for sample in calibration_samples:
        encoded = tokenizer(sample, return_tensors="pt")
        model(**encoded)
        total_num_tokens_observed += encoded.input_ids.numel()

    # drop the None entries
    # (modules for which no token count in the appropriate format was observed)
    counter = {
        name: count
        for name, count in get_observer_token_count(model).items()
        if count is not None
    }

    # each Mixtral expert is comprised of 3 linear layers,
    # hence the factor of 3 in the expected per-layer expert total
    expected_expert_tokens = (
        total_num_tokens_observed * model.config.num_experts_per_tok * 3
    )

    # walk over every layer whose token counts have been recorded
    for layer_idx in range(model.config.num_hidden_layers):
        # the router sees every token exactly once
        router_tokens = counter.pop(
            f"model.layers.{layer_idx}.block_sparse_moe.gate"
        )
        assert router_tokens == total_num_tokens_observed

        # collect the keys first so the dict is not mutated while iterating
        expert_keys = [
            name
            for name in counter
            if f"model.layers.{layer_idx}.block_sparse_moe.experts" in name
        ]
        expert_tokens = sum(counter.pop(name) for name in expert_keys)
        assert expert_tokens == expected_expert_tokens

    # every recorded entry must have been consumed by the checks above
    assert len(counter) == 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import pytest | ||
import torch | ||
from compressed_tensors.quantization.quant_args import QuantizationArgs | ||
|
||
from llmcompressor.observers import Observer | ||
|
||
|
||
def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
    """
    Build a randomly shuffled group-index tensor.

    Column ``c`` is assigned group id ``c // group_size``; the assignments are
    then reordered with a random permutation drawn from ``torch.randperm``.

    :param columns: number of entries in the returned tensor
    :param group_size: number of consecutive columns sharing one group id
    :return: 1-D tensor of length ``columns`` with the shuffled group ids
    """
    shuffle = torch.randperm(columns)
    group_ids = torch.tensor([col // group_size for col in range(columns)])
    return group_ids[shuffle]
|
||
|
||
@pytest.mark.parametrize(
    "symmetric,expected_scale,expected_zero_point",
    [
        (True, 0.0078, 0),
        (False, 0.0039, -128),
    ],
)
def test_min_max_observer(symmetric, expected_scale, expected_zero_point):
    """
    Check scale and zero point computed for a constant tensor of ones.

    The observer is whichever one ``QuantizationArgs.get_observer()`` names
    when no ``observer`` argument is given (presumably min-max, going by the
    test name -- confirm against ``QuantizationArgs``).

    :param symmetric: whether symmetric quantization is requested
    :param expected_scale: expected scale, rounded to 4 decimal places
    :param expected_zero_point: expected zero point, rounded to 4 decimal places
    """
    quant_args = QuantizationArgs(num_bits=8, symmetric=symmetric)
    observer = Observer.load_from_registry(
        quant_args.get_observer(), quantization_args=quant_args
    )

    ones = torch.tensor([1, 1, 1, 1, 1])
    scale, zero_point = observer(ones)

    assert round(scale.item(), 4) == expected_scale
    assert round(zero_point.item(), 4) == expected_zero_point
|
||
|
||
def test_min_max_observer_symmetric_scale_range():
    """
    Check the symmetric scale stays within bounds for values in ``[0, 127)``.

    A random 4x4 tensor scaled by 127 is observed with 8-bit symmetric
    arguments; the scale must not exceed 1.0039 and the zero point must be 0.
    """
    quant_args = QuantizationArgs(num_bits=8, symmetric=True)
    observer = Observer.load_from_registry(
        quant_args.get_observer(), quantization_args=quant_args
    )

    values = 127 * torch.rand(4, 4)
    scale, zero_point = observer(values)

    # upper bound for the symmetric case: 1.0039 ~= 2 * abs(-128) / 255
    assert round(scale.item(), 4) <= 1.0039
    assert round(zero_point.item(), 4) == 0
|
||
|
||
def test_min_max_observer_value_update():
    """
    Check how the observer's tracked min/max evolve across successive calls.

    The observer is fed five tensors: two all-ones inputs, one that raises the
    maximum to 127, another all-ones input, and one that lowers the minimum to
    -128. After each call, the values stored under the ``"default"`` key of
    ``observer.max_val`` / ``observer.min_val`` are folded into a running
    maximum / minimum and compared with the expected values.

    NOTE(review): the expected 2.26 and -0.29 are consistent with a moving
    average using a constant of 0.01 (1 + 0.01 * 126 and 1 + 0.01 * -129);
    confirm against the observer implementation.
    """
    inp = torch.tensor([1, 1, 1, 1, 1])
    inp_update_max = torch.tensor([127, 1, 1, 1, 1])
    inp_update_min = torch.tensor([-128, 1, 1, 1, 1])

    delta = 1e-6

    # update the min, max twice total
    tensors = [
        inp,
        inp,
        inp_update_max,  # update max
        inp,
        inp_update_min,  # update min
    ]

    weights = QuantizationArgs(num_bits=8, symmetric=True)
    observer = Observer.load_from_registry(
        weights.get_observer(), quantization_args=weights
    )

    curr_max = 1
    curr_min = 1
    for i, tensor in enumerate(tensors):
        observer(tensor)
        curr_max = max(observer.max_val.get("default"), curr_max)
        # the running minimum is folded with curr_min (not curr_max), so a
        # lower value seen earlier can never be lost
        curr_min = min(observer.min_val.get("default"), curr_min)

        if i < 2:
            # only all-ones inputs so far
            assert curr_max == 1
            assert curr_min == 1
        elif i < 4:
            # the max has been updated, the min has not
            assert abs(curr_max - 2.2600) < delta
            assert curr_min == 1
        else:
            # both the max and the min have been updated
            assert abs(curr_max - 2.2600) < delta
            assert abs(curr_min - (-0.2900)) < delta
|
||
|
||
def test_g_idx():
    """
    Check that passing ``g_idx`` matches observing pre-sorted columns.

    The observer is run once on the raw tensor together with a shuffled group
    index, then reset and run on the same tensor with its columns reordered by
    ``torch.argsort(g_idx)``. Both runs must yield approximately equal scales
    and zero points.
    """
    group_size = 2
    tensor = torch.rand((128, 512))
    quant_args = QuantizationArgs(num_bits=8, group_size=group_size)
    g_idx = make_dummy_g_idx(tensor.shape[1], group_size)

    observer = Observer.load_from_registry(
        quant_args.get_observer(), quantization_args=quant_args
    )
    scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx)

    # start from a clean observer before the reference run
    observer.reset()
    columns_by_group = tensor[:, torch.argsort(g_idx)]
    scale, zero_point = observer(columns_by_group)

    assert scale_g_idx == pytest.approx(scale)
    assert zero_point_g_idx == pytest.approx(zero_point)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import pytest | ||
import torch | ||
from compressed_tensors.quantization.quant_args import QuantizationArgs | ||
|
||
from llmcompressor.observers import MovingAverageMSEObserver, Observer | ||
|
||
|
||
@pytest.mark.parametrize(
    "symmetric,expected_scale,expected_zero_point",
    [
        (True, 0.0078, 0),
        (False, 0.0039, -128),
    ],
)
def test_mse_observer(symmetric, expected_scale, expected_zero_point):
    """
    Check scale and zero point from the ``"mse"`` observer for a tensor of ones.

    Also asserts that requesting ``observer="mse"`` resolves to a
    ``MovingAverageMSEObserver`` instance.

    :param symmetric: whether symmetric quantization is requested
    :param expected_scale: expected scale, rounded to 4 decimal places
    :param expected_zero_point: expected zero point, rounded to 4 decimal places
    """
    quant_args = QuantizationArgs(num_bits=8, symmetric=symmetric, observer="mse")
    observer = Observer.load_from_registry(
        quant_args.get_observer(), quantization_args=quant_args
    )

    ones = torch.tensor([1, 1, 1, 1, 1])
    scale, zero_point = observer(ones)

    assert isinstance(observer, MovingAverageMSEObserver)
    assert round(scale.item(), 4) == expected_scale
    assert round(zero_point.item(), 4) == expected_zero_point
|
||
|
||
def test_mse_observer_symmetric_scale_range():
    """
    Check the MSE observer's symmetric scale stays within bounds.

    A random 4x4 tensor scaled by 127 is observed with 8-bit symmetric
    arguments that explicitly request the ``"mse"`` observer; the scale must
    not exceed 1.0039 and the zero point must be 0.
    """
    tensor = torch.rand(4, 4)
    tensor *= 127

    num_bits = 8
    # request the MSE observer explicitly; without ``observer="mse"`` the
    # default observer would be exercised instead of the one under test
    weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="mse")

    observer = weights.get_observer()
    observer = Observer.load_from_registry(observer, quantization_args=weights)
    scale, zero_point = observer(tensor)

    # guard against silently testing a different observer
    assert isinstance(observer, MovingAverageMSEObserver)
    # upper bound for the symmetric case: 1.0039 ~= 2 * abs(-128) / 255
    assert round(scale.item(), 4) <= 1.0039
    assert round(zero_point.item(), 4) == 0