#14155: Add bfloat8_b tests for moreh_sgd, moreh_linear, moreh_linear_backward (#14156)

* #14155: add bfp8_b tests for moreh_sgd, moreh_linear and moreh_linear_backward

* #14155: remove create_ttnn_tilized_tensor
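
A minimal sketch (editor's addition, not part of the commit) of the bfloat8_b round-trip these tests exercise. It assumes a ttnn-capable device at device_id 0; ttnn.open_device/ttnn.close_device and the 32x32 shape are illustrative, while ttnn.from_torch and ttnn.to_torch are the calls actually used in the diff below.

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# bfloat8_b is a block floating-point format; tensors must be tilized (TILE_LAYOUT).
torch_input = torch.randint(-2, 3, (1, 1, 32, 32), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT)

# Round-trip back to torch; bfloat8_b quantization makes this lossy, which is why
# the tests compare with comp_allclose_and_pcc rather than exact equality.
round_tripped = ttnn.to_torch(tt_input).to(torch.bfloat16)

ttnn.close_device(device)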
DuongQLee authored Oct 31, 2024
1 parent 0add329 commit fe1129c
Showing 4 changed files with 256 additions and 97 deletions.
213 changes: 185 additions & 28 deletions tests/ttnn/unit_tests/operations/test_moreh_linear.py
@@ -6,7 +6,6 @@
import torch
import ttnn
from models.utility_functions import comp_allclose_and_pcc, skip_for_grayskull
from tests.ttnn.unit_tests.operations.test_moreh_matmul import get_tensors, get_bias_tensors
from loguru import logger
from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
@@ -16,24 +15,128 @@
)


def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device):
def get_tensors(
input_shape,
mat2_shape,
output_shape,
require_input_grad,
require_mat2_grad,
device,
*,
torch_dtype=torch.float32,
ttnn_dtype=ttnn.bfloat16,
use_randint=True,
is_1d=False,
low_int=-2,
high_int=3,
):
"""
Returns tensors for input, mat2, output, and their gradients (if required), both in ttnn and torch:
0. ttnn - tilized input tensor
1. ttnn - tilized mat2 tensor
2. ttnn - tilized output tensor
3. ttnn - tilized output gradient tensor (if required), otherwise None
4. ttnn - tilized input gradient tensor (if required), otherwise None
5. ttnn - tilized mat2 gradient tensor (if required), otherwise None
6. torch input tensor
7. torch mat2 tensor
8. torch output gradient tensor (if required), otherwise None
"""
if use_randint:
tensors = [
torch.randint(low=low_int, high=high_int, size=input_shape, dtype=torch_dtype), # input
torch.randint(low=low_int, high=high_int, size=mat2_shape, dtype=torch_dtype), # other
torch.randint(low=low_int, high=high_int, size=output_shape, dtype=torch_dtype), # output
(
torch.randint(low=low_int, high=high_int, size=output_shape, dtype=torch_dtype)
if (require_input_grad or require_mat2_grad)
else None
), # output_grad
torch.full(input_shape, float("nan"), dtype=torch_dtype) if require_input_grad else None, # input_grad
torch.full(mat2_shape, float("nan"), dtype=torch_dtype) if require_mat2_grad else None, # other_grad
]
else:
tensors = [
torch.rand(input_shape, dtype=torch_dtype), # input
torch.rand(mat2_shape, dtype=torch_dtype), # other
torch.rand(output_shape, dtype=torch_dtype), # output
(
torch.rand(output_shape, dtype=torch_dtype) if (require_input_grad or require_mat2_grad) else None
), # output_grad
torch.full(input_shape, float("nan"), dtype=torch_dtype) if require_input_grad else None, # input_grad
torch.full(mat2_shape, float("nan"), dtype=torch_dtype) if require_mat2_grad else None, # other_grad
]

if is_1d:
tensors[0] = tensors[0].reshape(-1)
tensors[1] = tensors[1].reshape(-1)

ttnn_tensors = [
ttnn.from_torch(tensor, device=device, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT)
if tensor is not None
else None
for tensor in tensors
]

return (*ttnn_tensors, tensors[0], tensors[1], tensors[3])
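
# Editor's sketch (not part of this commit): a typical call unpacks the 9-tuple
# documented above, e.g. for a forward-only case with no gradients required:
#   tt_in, tt_mat2, tt_out, _, _, _, torch_in, torch_mat2, _ = get_tensors(
#       [1, 1, 32, 32], [1, 1, 32, 32], [1, 1, 32, 32],
#       False, False, device, ttnn_dtype=ttnn.bfloat8_b,
#   )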


def get_bias_tensors(
bias_shape,
require_bias_grad,
device,
*,
torch_dtype=torch.float32,
ttnn_dtype=ttnn.bfloat16,
use_randint=True,
low_int=-10,
high_int=10,
):
"""
Returns tensors for bias and bias_grad (if required):
0. ttnn - tilized bias tensor
1. torch - bias tensor
2. ttnn - tilized bias grad tensor
"""
if use_randint:
torch_bias = torch.randint(low=low_int, high=high_int, size=bias_shape, dtype=torch_dtype)
else:
torch_bias = torch.rand(size=bias_shape, dtype=torch_dtype) * 10 - 5

tt_bias = ttnn.from_torch(torch_bias, device=device, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT)

tt_bias_grad = None
if require_bias_grad:
bias_grad = torch.full(bias_shape, float("nan"), dtype=torch_dtype)
tt_bias_grad = ttnn.from_torch(bias_grad, device=device, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT)
return tt_bias, torch_bias, tt_bias_grad
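
# Editor's sketch (not part of this commit): typical usage, mirroring the calls
# further down in this file; the bias shape is illustrative:
#   tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(
#       [1, 1, 1, 32], True, device, torch_dtype=torch.bfloat16, ttnn_dtype=ttnn.bfloat8_b,
#   )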


def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device, npu_dtype=ttnn.bfloat16):
torch.manual_seed(3072)
cpu_dtype = torch.bfloat16
input_shape, weight_shape, bias_shape, output_shape = shapes
tt_input, tt_weight, _, _, _, _, torch_input, torch_weight, _ = get_tensors(
input_shape, weight_shape, output_shape, False, False, False, device
input_shape,
weight_shape,
output_shape,
False,
False,
is_1d=False,
device=device,
ttnn_dtype=npu_dtype,
torch_dtype=cpu_dtype,
)

npu_dtype = ttnn.bfloat16
npu_layout = ttnn.TILE_LAYOUT
cpu_dtype = torch.bfloat16
torch_output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
tt_output = ttnn.Tensor(torch_output, npu_dtype).pad_to_tile(1).to(npu_layout).to(device) if has_output else None
tt_output = ttnn.from_torch(torch_output, npu_dtype, device=device, layout=ttnn.TILE_LAYOUT) if has_output else None

if has_bias:
tt_bias, torch_bias, _ = get_bias_tensors(bias_shape, False, device)
tt_bias, torch_bias, _ = get_bias_tensors(
bias_shape, False, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype
)
else:
tt_bias, torch_bias = None, None

## TT Op
tt_output = ttnn.operations.moreh.linear(
tt_input, tt_weight, bias=tt_bias, output=tt_output, compute_kernel_config=compute_kernel_config
@@ -44,8 +147,7 @@ def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device):

## test for equivalence
rtol = atol = 0.1
cpu_layout = ttnn.ROW_MAJOR_LAYOUT
ttcpu_output = tt_output.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch()
ttcpu_output = ttnn.to_torch(tt_output).to(cpu_dtype)
passing, output_pcc = comp_allclose_and_pcc(torch_output, ttcpu_output, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Passing = {passing}")
logger.debug(f"Output PCC = {output_pcc}")
@@ -74,10 +176,14 @@ def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device):
)
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_linear(shapes, has_bias, compute_kernel_options, device):
@pytest.mark.parametrize("npu_dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["BFP8", "BFP16"])
def test_moreh_linear(shapes, has_bias, compute_kernel_options, npu_dtype, device):
if npu_dtype == ttnn.bfloat8_b:
# Failing test cases produce 0.0 outputs as well as precision issues.
pytest.skip("Moreh Linear does not support bfloat8_b")
torch.manual_seed(3072)
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
passing = moreh_linear(shapes, has_bias, True, compute_kernel_config, device)
passing = moreh_linear(shapes, has_bias, True, compute_kernel_config, device, npu_dtype)
assert passing


@@ -90,10 +196,14 @@ def test_moreh_linear(shapes, has_bias, compute_kernel_options, device):
),
)
@pytest.mark.parametrize("has_bias", [False, True])
def test_moreh_linear_wo_output(shapes, has_bias, device):
@pytest.mark.parametrize("npu_dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["BFP8", "BFP16"])
def test_moreh_linear_wo_output(shapes, has_bias, npu_dtype, device):
if npu_dtype == ttnn.bfloat8_b:
# Failing test cases produce 0.0 outputs as well as precision issues.
pytest.skip("Moreh Linear does not support bfloat8_b")
torch.manual_seed(3072)
compute_kernel_config = get_compute_kernel_options(False)
passing = moreh_linear(shapes, has_bias, False, compute_kernel_config, device)
passing = moreh_linear(shapes, has_bias, False, compute_kernel_config, device, npu_dtype)
assert passing


@@ -120,12 +230,20 @@ def test_moreh_linear_enable_cache(shapes, device, use_program_cache):


def moreh_linear_backward(
shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device
shapes,
requires_input_grad,
requires_weight_grad,
requires_bias_grad,
compute_kernel_config,
device,
npu_dtype=ttnn.bfloat16,
):
input_shape, weight_shape, bias_shape, output_shape = shapes
if not requires_input_grad and not requires_weight_grad and not requires_bias_grad:
pytest.skip("At least one grad is requires")

cpu_dtype = torch.bfloat16

(
tt_input,
tt_weight,
@@ -136,9 +254,21 @@ def moreh_linear_backward(
torch_input,
torch_weight,
torch_output_grad,
) = get_tensors(input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device)
) = get_tensors(
input_shape,
weight_shape,
output_shape,
requires_input_grad,
requires_weight_grad,
is_1d=False,
device=device,
ttnn_dtype=npu_dtype,
torch_dtype=cpu_dtype,
)

tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(bias_shape, requires_bias_grad, device)
tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(
bias_shape, requires_bias_grad, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype
)

## tt linear backward
tt_input_grad, tt_weight_grad, tt_bias_grad = ttnn.operations.moreh.linear_backward(
@@ -162,17 +292,16 @@ def moreh_linear_backward(

## test for equivalence
rtol = atol = 0.1
cpu_layout = ttnn.ROW_MAJOR_LAYOUT
if requires_input_grad:
ttcpu_input_grad = tt_input_grad.cpu().to(cpu_layout).unpad_from_tile(input_shape).to_torch()
ttcpu_input_grad = ttnn.to_torch(tt_input_grad).to(cpu_dtype)
passing, output_pcc = comp_allclose_and_pcc(torch_input.grad, ttcpu_input_grad, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"input_grad passing={passing} pcc={output_pcc}")
assert passing
else:
assert tt_input_grad is None

if requires_weight_grad:
ttcpu_weight_grad = tt_weight_grad.cpu().to(cpu_layout).unpad_from_tile(weight_shape).to_torch()
ttcpu_weight_grad = ttnn.to_torch(tt_weight_grad).to(cpu_dtype)
passing, output_pcc = comp_allclose_and_pcc(
torch_weight.grad, ttcpu_weight_grad, pcc=0.999, rtol=rtol, atol=atol
)
@@ -182,7 +311,7 @@ def moreh_linear_backward(
assert tt_weight_grad is None

if requires_bias_grad:
ttcpu_bias_grad = tt_bias_grad.cpu().to(cpu_layout).unpad_from_tile(bias_shape).to_torch()
ttcpu_bias_grad = ttnn.to_torch(tt_bias_grad).to(cpu_dtype)
passing, output_pcc = comp_allclose_and_pcc(torch_bias.grad, ttcpu_bias_grad, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"bias_grad passing={passing} pcc={output_pcc}")
assert passing
@@ -221,12 +350,15 @@ def moreh_linear_backward(
)
@pytest.mark.parametrize("requires_bias_grad", [True, False])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_linear_backward(shapes, requires_grads, requires_bias_grad, compute_kernel_options, device):
@pytest.mark.parametrize("npu_dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["BFP8", "BFP16"])
def test_moreh_linear_backward(shapes, requires_grads, requires_bias_grad, compute_kernel_options, npu_dtype, device):
if npu_dtype == ttnn.bfloat8_b:
pytest.skip("Moreh Linear Backward does not support bfloat8_b")
torch.manual_seed(3072)
requires_input_grad, requires_weight_grad = requires_grads
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
passing = moreh_linear_backward(
shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device
shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device, npu_dtype
)
assert passing

@@ -276,6 +408,9 @@ def test_moreh_bias_backward_fp32(shapes, device):
compute_kernel_fp32_config = get_compute_kernel_options(True)
compute_kernel_config = get_compute_kernel_options(False)
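# Editor's assumption (the helper lives in tests/ttnn/unit_tests/operations/test_utils.py,
# which is not shown in this diff): get_compute_kernel_options(True) presumably enables
# fp32 destination accumulation, along the lines of:
#   ttnn.WormholeComputeKernelConfig(
#       math_fidelity=ttnn.MathFidelity.HiFi4,
#       fp32_dest_acc_en=True,  # accumulate matmul partials in fp32 rather than bf16
#       packer_l1_acc=False,
#   )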
requires_input_grad, requires_weight_grad, requires_bias_grad = (True, False, True)
npu_dtype = ttnn.bfloat16
cpu_dtype = torch.bfloat16

input_shape, weight_shape, bias_shape, output_shape = shapes
(
tt_input,
@@ -288,13 +423,35 @@ def test_moreh_bias_backward_fp32(shapes, device):
torch_weight,
torch_output_grad,
) = get_tensors(
input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device, False
input_shape,
weight_shape,
output_shape,
requires_input_grad,
requires_weight_grad,
is_1d=False,
device=device,
ttnn_dtype=npu_dtype,
torch_dtype=cpu_dtype,
use_randint=False,
)
tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(
bias_shape, requires_bias_grad, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype, use_randint=False
)
tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(bias_shape, requires_bias_grad, device, False)
(_, _, _, _, tt_input_grad_fp32, _, _, _, _) = get_tensors(
input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device, False
input_shape,
weight_shape,
output_shape,
requires_input_grad,
requires_weight_grad,
is_1d=False,
device=device,
ttnn_dtype=npu_dtype,
torch_dtype=cpu_dtype,
use_randint=False,
)
(_, _, tt_bias_grad_fp32) = get_bias_tensors(
bias_shape, requires_bias_grad, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype, use_randint=False
)
(_, _, tt_bias_grad_fp32) = get_bias_tensors(bias_shape, requires_bias_grad, device, False)
## tt linear backward (fp32 mode)
tt_input_grad_fp32, _, tt_bias_grad_fp32 = ttnn.operations.moreh.linear_backward(
tt_output_grad,