forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dynamic_quant.py
85 lines (67 loc) · 2.67 KB
/
dynamic_quant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from .utils import (
dynamically_quantize_per_channel,
quant_int8_dynamic_per_token_linear,
)
__all__ = ["DynamicallyPerAxisQuantizedLinear"]
class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    Drop-in replacement for `torch.nn.Linear` that performs the matmul with
    int8 dynamic symmetric per-token activation quantization and int8
    symmetric per-channel weight quantization.

    Instances are normally created via `from_float`, which copies an existing
    float `nn.Linear`'s weights in quantized form (buffer `W_int_repr_t` plus
    per-channel scales `W_scales`) and removes the float `weight` parameter.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
    ) -> None:
        # No extra state here; quantized tensors are attached by `from_float`.
        super().__init__(in_features, out_features, bias)

    def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Run the quantized linear layer: int8 dynamic symmetric per-token
        activation quantization plus int8 symmetric per-channel weights.

        Args:
            X (torch.Tensor): Floating point input tensor.

        Returns:
            torch.Tensor: Floating point output after the quantized matmul
            and rescale (same dtype as `X`).
        """
        # Extra positional/keyword args are accepted for API compatibility
        # but ignored, matching the original implementation.
        return quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
        )

    @classmethod
    def from_float(cls, mod: torch.nn.Linear) -> "DynamicallyPerAxisQuantizedLinear":
        """
        Build a `DynamicallyPerAxisQuantizedLinear` from a float `nn.Linear`.

        Args:
            mod (torch.nn.Linear): The float module to convert.

        Returns:
            DynamicallyPerAxisQuantizedLinear: The quantized replacement,
            moved to the same device as `mod`.
        """
        # Construct with a tiny placeholder size so the (discarded) float
        # weight allocation in nn.Linear.__init__ stays cheap, then patch
        # the real feature dimensions back on.
        placeholder_dim = 8
        qlinear = cls(placeholder_dim, placeholder_dim, bias=mod.bias is not None)
        qlinear.in_features = mod.in_features
        qlinear.out_features = mod.out_features

        # Symmetric int8 per-channel quantization of the float weight;
        # zero points are all zero for the symmetric scheme and unused.
        int_weight, scales, _zero_points = dynamically_quantize_per_channel(
            mod.weight, -128, 127, torch.int8
        )
        # Store the weight pre-transposed for the quantized matmul.
        qlinear.register_buffer("W_int_repr_t", int_weight.contiguous().t())
        qlinear.W_scales = nn.Parameter(scales)
        qlinear.bias = mod.bias

        # The float weight from nn.Linear.__init__ is no longer needed.
        del qlinear.weight

        # Match the source module's device.
        qlinear.to(next(mod.parameters()).device)
        return qlinear