forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
float8_utils.py
245 lines (197 loc) · 7.87 KB
/
float8_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable, Literal, Tuple, Union
import torchao.float8.config as config
import torch
import torch.distributed as dist
# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html

# Lower bound used when computing a scale, so that the division by amax
# never divides by zero.
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

# True when torch reports an available CUDA-API device and this torch build
# exposes a HIP version.
IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None

# The float8 dtypes recognised by the helpers in this module.
FP8_TYPES = {
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
}

# Config-selected float8 dtypes: the "fnuz" variants when
# `config.use_fnuz_dtype` is truthy, the regular variants otherwise.
e4m3_dtype = torch.float8_e4m3fnuz if config.use_fnuz_dtype else torch.float8_e4m3fn
e5m2_dtype = torch.float8_e5m2fnuz if config.use_fnuz_dtype else torch.float8_e5m2
@torch.no_grad()
def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """Derive the fp8 scale from a tensor's amax value.

    The scale is the largest finite value of `float8_dtype` divided by
    `amax`, with `amax` first clamped from below to `EPS`.

    Args:
        amax: The amax value of the tensor.
        float8_dtype: The float8 dtype; must be a member of `FP8_TYPES`.
        orig_dtype: The original dtype of the tensor.

    Returns:
        The scale, converted to float32.

    Raises:
        ValueError: If `float8_dtype` is not a member of `FP8_TYPES`.
    """
    if float8_dtype not in FP8_TYPES:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    fp8_max = torch.finfo(float8_dtype).max
    scale = fp8_max / torch.clamp(amax, min=EPS)

    # Keep the scale representable in float16; this matters when amax is
    # small. The assumption is that float32/bfloat16 do not need this cap.
    if orig_dtype is torch.float16:
        fp16_max = torch.finfo(torch.float16).max
        scale = torch.clamp(scale, max=fp16_max)

    return scale.to(torch.float32)
@torch.no_grad()
def amax_history_to_scale(
amax_history: torch.Tensor,
float8_dtype: torch.Tensor,
orig_dtype: torch.dtype,
history_to_scale_fn_type: Literal["max"],
):
"""Takes in a history of amax values and returns a scale tensor.
Args:
amax_history: A tensor containing the history of amax values.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
history_to_scale_fn_type: The type of function to use to convert the history to a scale.
"""
if history_to_scale_fn_type == "max":
amax = torch.max(amax_history)
return amax_to_scale(amax, float8_dtype, orig_dtype)
raise NotImplementedError()
@torch.no_grad()
def amax_history_to_scale_stack(
amax_history: torch.Tensor,
float8_dtype: torch.dtype,
orig_dtype: torch.dtype,
history_to_scale_fn_type: Literal["max"],
) -> torch.Tensor:
"""Takes in a stack of amax_history tensors and returns a scale tensor.
Args:
amax_history: A 2D tensor containing a stack of amax histories.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
history_to_scale_fn_type: The type of function to use to convert the history to a scale.
"""
if history_to_scale_fn_type == "max":
amax_stack = torch.max(amax_history, dim=1).values
return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
raise NotImplementedError(
f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
)
@torch.no_grad()
def tensor_to_amax(
    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
) -> torch.Tensor:
    """Return the maximum absolute value of `x`.

    Args:
        x: The tensor to reduce.
        reduce_amax: When True and `torch.distributed` is initialized, the
            local result is all-reduced in place with the MAX op.
        device_mesh: Optional object whose `get_group()` supplies the process
            group for the all-reduce; when None, `group=None` is passed.

    Returns:
        The amax of `x` (after the distributed reduction, if one ran).
    """
    local_amax = x.abs().max()

    # Only reduce across ranks when the caller asked for it; otherwise the
    # assumption is that the reduction happens elsewhere.
    if not reduce_amax or not dist.is_initialized():
        return local_amax

    group = None if device_mesh is None else device_mesh.get_group()
    dist.all_reduce(local_amax, op=dist.ReduceOp.MAX, group=group)
    return local_amax
@torch.no_grad()
def tensor_to_scale(
    x: torch.Tensor,
    float8_dtype: torch.dtype,
    reduce_amax: bool = False,
    device_mesh=None,
) -> torch.Tensor:
    """Compute the fp8 scale for `x` directly from its contents.

    Args:
        x: The tensor to compute a scale for; its dtype is forwarded to
            `amax_to_scale` as the original dtype.
        float8_dtype: The float8 dtype.
        reduce_amax: Forwarded to `tensor_to_amax`.
        device_mesh: Forwarded to `tensor_to_amax`.

    Returns:
        The scale produced by `amax_to_scale` for the amax of `x`.
    """
    return amax_to_scale(
        tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh),
        float8_dtype,
        x.dtype,
    )
def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
    """Cast `x` to `float8_dtype`, saturating out-of-range values.

    Values are clamped to `[-max, max]` of the target dtype before the cast.

    Note:
        The default behavior in PyTorch for casting to `float8_e4m3fn`
        and `e5m2` is to not saturate. In this context, we should saturate.
        A common case where we want to saturate is when the history of a
        tensor has a maximum value of `amax1`, and the current amax value
        is `amax2`, where `amax1 < amax2`. This is common when using delayed
        scaling.

    Args:
        x: The tensor to convert.
        float8_dtype: The target float8 dtype; must be in `FP8_TYPES`.

    Raises:
        ValueError: If `float8_dtype` is not in `FP8_TYPES`.
    """
    if float8_dtype not in FP8_TYPES:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    limit = torch.finfo(float8_dtype).max
    saturated = x.clamp(min=-limit, max=limit)
    return saturated.to(float8_dtype)
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the signal-to-noise ratio between `x` and `y` in dB.

    Computed as `20 * log10(norm(x) / norm(x - y))`.
    For more details see:
    https://en.wikipedia.org/wiki/Signal-to-noise_ratio

    Args:
        x: The original tensor.
        y: The tensor to compare to the original tensor.
    """
    signal_norm = torch.norm(x)
    noise_norm = torch.norm(x - y)
    ratio = signal_norm / noise_norm
    return 20 * torch.log10(ratio)
def fp8_tensor_statistics(
    tensor: torch.Tensor, float8_dtype=e4m3_dtype
) -> Tuple[int, ...]:
    """Count the zero and saturated elements of an FP8 tensor.

    Args:
        tensor: The tensor to calculate stats for. It must expose `_data`
            and `_orig_dtype` attributes (presumably a float8 tensor
            wrapper -- confirm against callers).
        float8_dtype: The float8 dtype whose maximum is used as the
            saturation value; must be in `FP8_TYPES`.

    Returns:
        A tuple containing the number of zeros and the number of max values.

    Raises:
        ValueError: If `float8_dtype` is not in `FP8_TYPES`.
    """
    if float8_dtype not in FP8_TYPES:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    fp8_max = torch.finfo(float8_dtype).max

    # Compare in the original dtype rather than on the raw fp8 payload.
    values = tensor._data.to(dtype=tensor._orig_dtype)
    saturated_count = (values.abs() == fp8_max).sum().item()
    zero_count = (values == 0).sum().item()
    return (zero_count, saturated_count)
def is_row_major(stride):
    """Return whether a 2D stride describes a row-major layout.

    The layout counts as row-major when the first stride is strictly larger
    than the second and the second stride is exactly 1.

    Args:
        stride: A length-2 sequence of strides.
    """
    assert len(stride) == 2, "is_row_major only supports 2D tensors"
    outer_stride, inner_stride = stride
    return outer_stride > inner_stride and inner_stride == 1
def _get_min_alignment(size: int, alignment_value: int) -> int:
    """
    Round `size` up to the nearest multiple of `alignment_value`.

    Args:
        size: The size of the data to be aligned.
        alignment_value: The alignment value to be used.

    Returns:
        int: The smallest multiple of `alignment_value` that is greater than
        or equal to `size` (for a positive `alignment_value`).

    Usage:
    ```
        >>> _get_min_alignment(10, 8)
        16
    ```
    """
    # Number of whole alignment units that fit strictly below `size`.
    units_below = (size - 1) // alignment_value
    return (units_below + 1) * alignment_value
def pad_tensor_for_matmul(
    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
) -> torch.Tensor:
    """
    Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16,
    which is required by `torch._scaled_mm`.

    Args:
        tensor: The 2D tensor to pad.
        dims: Dimensions to pad: a single index or any iterable of indices.
            Only the indices 0 and 1 are recognised; other values are ignored.

    Returns:
        torch.Tensor: The padded tensor. Padding is appended at the end of
        each selected dimension.

    Usage:
    ```
        >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape
        torch.Size([16, 10])
        >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape
        torch.Size([10, 16])
        >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape
        torch.Size([16, 16])
    ```
    """
    assert tensor.dim() == 2
    dim1, dim2 = tensor.shape

    # Materialise `dims` exactly once. It is probed twice with `in` below, and
    # a one-shot iterable (generator/iterator) would be partially or fully
    # consumed by the first probe, making the second one give a wrong answer.
    dims = (dims,) if isinstance(dims, int) else tuple(dims)

    # Calculate aligned dimensions based on the specified dims
    dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
    dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2

    # Calculate padding values for both dimensions
    pad_dim1 = dim1_aligned - dim1
    pad_dim2 = dim2_aligned - dim2

    # The padding tuple is ordered last dimension first:
    # (dim2 before, dim2 after, dim1 before, dim1 after).
    return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))