diff --git a/tests/ttnn/unit_tests/operations/test_moreh_linear.py b/tests/ttnn/unit_tests/operations/test_moreh_linear.py
index 1f8ab7cb592..d5368466b29 100644
--- a/tests/ttnn/unit_tests/operations/test_moreh_linear.py
+++ b/tests/ttnn/unit_tests/operations/test_moreh_linear.py
@@ -6,7 +6,6 @@
 import torch
 import ttnn
 from models.utility_functions import comp_allclose_and_pcc, skip_for_grayskull
-from tests.ttnn.unit_tests.operations.test_moreh_matmul import get_tensors, get_bias_tensors
 from loguru import logger
 from tests.ttnn.unit_tests.operations.test_utils import (
     get_compute_kernel_options,
@@ -16,24 +15,128 @@
 )
 
 
-def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device):
+def get_tensors(
+    input_shape,
+    mat2_shape,
+    output_shape,
+    require_input_grad,
+    require_mat2_grad,
+    device,
+    *,
+    torch_dtype=torch.float32,
+    ttnn_dtype=ttnn.bfloat16,
+    use_randint=True,
+    is_1d=False,
+    low_int=-2,
+    high_int=3,
+):
+    """
+    Returns tensors for input, mat2, output, and their gradients (if required), both in ttnn and torch:
+    0. ttnn - tilized input tensor
+    1. ttnn - tilized mat2 tensor
+    2. ttnn - tilized output tensor
+    3. ttnn - tilized output gradient tensor (if required), otherwise None
+    4. ttnn - tilized input gradient tensor (if required), otherwise None
+    5. ttnn - tilized mat2 gradient tensor (if required), otherwise None
+    6. torch input tensor
+    7. torch mat2 tensor
+    8. torch output gradient tensor (if required), otherwise None
+    """
+    if use_randint:
+        tensors = [
+            torch.randint(low=low_int, high=high_int, size=input_shape, dtype=torch_dtype),  # input
+            torch.randint(low=low_int, high=high_int, size=mat2_shape, dtype=torch_dtype),  # other
+            torch.randint(low=low_int, high=high_int, size=output_shape, dtype=torch_dtype),  # output
+            (
+                torch.randint(low=low_int, high=high_int, size=output_shape, dtype=torch_dtype)
+                if (require_input_grad or require_mat2_grad)
+                else None
+            ),  # output_grad
+            torch.full(input_shape, float("nan"), dtype=torch_dtype) if require_input_grad else None,  # input_grad
+            torch.full(mat2_shape, float("nan"), dtype=torch_dtype) if require_mat2_grad else None,  # other_grad
+        ]
+    else:
+        tensors = [
+            torch.rand(input_shape, dtype=torch_dtype),  # input
+            torch.rand(mat2_shape, dtype=torch_dtype),  # other
+            torch.rand(output_shape, dtype=torch_dtype),  # output
+            (
+                torch.rand(output_shape, dtype=torch_dtype) if (require_input_grad or require_mat2_grad) else None
+            ),  # output_grad
+            torch.full(input_shape, float("nan"), dtype=torch_dtype) if require_input_grad else None,  # input_grad
+            torch.full(mat2_shape, float("nan"), dtype=torch_dtype) if require_mat2_grad else None,  # other_grad
+        ]
+
+    if is_1d:
+        tensors[0] = tensors[0].reshape(-1)
+        tensors[1] = tensors[1].reshape(-1)
+
+    ttnn_tensors = [
+        ttnn.from_torch(tensor, device=device, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT)
+        if tensor is not None
+        else None
+        for tensor in tensors
+    ]
+
+    return (*ttnn_tensors, tensors[0], tensors[1], tensors[3])
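Note: a minimal usage sketch of the new `get_tensors` helper (not part of the patch; shapes are illustrative, and an open ttnn device is assumed, e.g. the pytest `device` fixture or `ttnn.open_device`):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Forward-only request: no gradients, so the three gradient slots come back as None.
    tt_input, tt_weight, tt_out, tt_out_grad, tt_in_grad, tt_w_grad, torch_input, torch_weight, torch_out_grad = get_tensors(
        [1, 1, 32, 64],  # input_shape
        [1, 1, 64, 32],  # mat2_shape
        [1, 1, 32, 32],  # output_shape
        False,           # require_input_grad
        False,           # require_mat2_grad
        device,
        torch_dtype=torch.bfloat16,
        ttnn_dtype=ttnn.bfloat16,
    )
    assert tt_out_grad is None and tt_in_grad is None and tt_w_grad is None

    ttnn.close_device(device)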
+
+
+def get_bias_tensors(
+    bias_shape,
+    require_bias_grad,
+    device,
+    *,
+    torch_dtype=torch.float32,
+    ttnn_dtype=ttnn.bfloat16,
+    use_randint=True,
+    low_int=-10,
+    high_int=10,
+):
+    """
+    Returns tensors for bias and bias_grad (if required):
+    0. ttnn - tilized bias tensor
+    1. torch - bias tensor
+    2. ttnn - tilized bias grad tensor (if required), otherwise None
+    """
+    if use_randint:
+        torch_bias = torch.randint(low=low_int, high=high_int, size=bias_shape, dtype=torch_dtype)
+    else:
+        torch_bias = torch.rand(size=bias_shape, dtype=torch_dtype) * 10 - 5
+
+    tt_bias = ttnn.from_torch(torch_bias, device=device, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT)
+
+    tt_bias_grad = None
+    if require_bias_grad:
+        bias_grad = torch.full(bias_shape, float("nan"), dtype=torch_dtype)
+        tt_bias_grad = ttnn.from_torch(bias_grad, device=device, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT)
+    return tt_bias, torch_bias, tt_bias_grad
+
+
+def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device, npu_dtype=ttnn.bfloat16):
     torch.manual_seed(3072)
+    cpu_dtype = torch.bfloat16
     input_shape, weight_shape, bias_shape, output_shape = shapes
     tt_input, tt_weight, _, _, _, _, torch_input, torch_weight, _ = get_tensors(
-        input_shape, weight_shape, output_shape, False, False, False, device
+        input_shape,
+        weight_shape,
+        output_shape,
+        False,
+        False,
+        is_1d=False,
+        device=device,
+        ttnn_dtype=npu_dtype,
+        torch_dtype=cpu_dtype,
     )
 
-    npu_dtype = ttnn.bfloat16
-    npu_layout = ttnn.TILE_LAYOUT
-    cpu_dtype = torch.bfloat16
     torch_output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
-    tt_output = ttnn.Tensor(torch_output, npu_dtype).pad_to_tile(1).to(npu_layout).to(device) if has_output else None
+    tt_output = ttnn.from_torch(torch_output, npu_dtype, device=device, layout=ttnn.TILE_LAYOUT) if has_output else None
 
     if has_bias:
-        tt_bias, torch_bias, _ = get_bias_tensors(bias_shape, False, device)
+        tt_bias, torch_bias, _ = get_bias_tensors(
+            bias_shape, False, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype
+        )
     else:
         tt_bias, torch_bias = None, None
-
     ## TT Op
     tt_output = ttnn.operations.moreh.linear(
         tt_input, tt_weight, bias=tt_bias, output=tt_output, compute_kernel_config=compute_kernel_config
@@ -44,8 +147,7 @@ def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device):
 
     ## test for equivalance
     rtol = atol = 0.1
-    cpu_layout = ttnn.ROW_MAJOR_LAYOUT
-    ttcpu_output = tt_output.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch()
+    ttcpu_output = ttnn.to_torch(tt_output).to(cpu_dtype)
     passing, output_pcc = comp_allclose_and_pcc(torch_output, ttcpu_output, pcc=0.999, rtol=rtol, atol=atol)
     logger.debug(f"Passing = {passing}")
    logger.debug(f"Output PCC = {output_pcc}")
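For context on the equivalence check above: the elided body of `moreh_linear` presumably builds the torch-side expectation in the usual way, with `torch.nn.functional.linear`, which computes `x @ W.T + b`. A hedged sketch (shapes illustrative):

    import torch

    x = torch.randn(1, 1, 32, 64, dtype=torch.bfloat16)
    W = torch.randn(32, 64, dtype=torch.bfloat16)  # (out_features, in_features)
    b = torch.randn(32, dtype=torch.bfloat16)

    expected = torch.nn.functional.linear(x, W, b)  # x @ W.T + b, shape (1, 1, 32, 32)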
@@ -74,10 +176,14 @@ def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device):
 )
 @pytest.mark.parametrize("has_bias", [False, True])
 @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
-def test_moreh_linear(shapes, has_bias, compute_kernel_options, device):
+@pytest.mark.parametrize("npu_dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["BFP8", "BFP16"])
+def test_moreh_linear(shapes, has_bias, compute_kernel_options, npu_dtype, device):
+    if npu_dtype == ttnn.bfloat8_b:
+        # Failing test cases produce 0.0 outputs as well as imprecise results.
+        pytest.skip("Moreh Linear does not support bfloat8_b")
     torch.manual_seed(3072)
     compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
-    passing = moreh_linear(shapes, has_bias, True, compute_kernel_config, device)
+    passing = moreh_linear(shapes, has_bias, True, compute_kernel_config, device, npu_dtype)
     assert passing
 
 
@@ -90,10 +196,14 @@ def test_moreh_linear(shapes, has_bias, compute_kernel_options, device):
     ),
 )
 @pytest.mark.parametrize("has_bias", [False, True])
-def test_moreh_linear_wo_output(shapes, has_bias, device):
+@pytest.mark.parametrize("npu_dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["BFP8", "BFP16"])
+def test_moreh_linear_wo_output(shapes, has_bias, npu_dtype, device):
+    if npu_dtype == ttnn.bfloat8_b:
+        # Failing test cases produce 0.0 outputs as well as imprecise results.
+        pytest.skip("Moreh Linear does not support bfloat8_b")
     torch.manual_seed(3072)
     compute_kernel_config = get_compute_kernel_options(False)
-    passing = moreh_linear(shapes, has_bias, False, compute_kernel_config, device)
+    passing = moreh_linear(shapes, has_bias, False, compute_kernel_config, device, npu_dtype)
     assert passing
 
 
@@ -120,12 +230,20 @@ def test_moreh_linear_enable_cache(shapes, device, use_program_cache):
 
 
 def moreh_linear_backward(
-    shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device
+    shapes,
+    requires_input_grad,
+    requires_weight_grad,
+    requires_bias_grad,
+    compute_kernel_config,
+    device,
+    npu_dtype=ttnn.bfloat16,
 ):
     input_shape, weight_shape, bias_shape, output_shape = shapes
     if not requires_input_grad and not requires_weight_grad and not requires_bias_grad:
         pytest.skip("At least one grad is requires")
 
+    cpu_dtype = torch.bfloat16
+
     (
         tt_input,
         tt_weight,
@@ -136,9 +254,21 @@ def moreh_linear_backward(
         torch_input,
         torch_weight,
         torch_output_grad,
-    ) = get_tensors(input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device)
+    ) = get_tensors(
+        input_shape,
+        weight_shape,
+        output_shape,
+        requires_input_grad,
+        requires_weight_grad,
+        is_1d=False,
+        device=device,
+        ttnn_dtype=npu_dtype,
+        torch_dtype=cpu_dtype,
+    )
 
-    tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(bias_shape, requires_bias_grad, device)
+    tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(
+        bias_shape, requires_bias_grad, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype
+    )
 
     ## tt linear backward
     tt_input_grad, tt_weight_grad, tt_bias_grad = ttnn.operations.moreh.linear_backward(
@@ -162,9 +292,8 @@ def moreh_linear_backward(
 
     ## test for equivalance
     rtol = atol = 0.1
-    cpu_layout = ttnn.ROW_MAJOR_LAYOUT
     if requires_input_grad:
-        ttcpu_input_grad = tt_input_grad.cpu().to(cpu_layout).unpad_from_tile(input_shape).to_torch()
+        ttcpu_input_grad = ttnn.to_torch(tt_input_grad).to(cpu_dtype)
         passing, output_pcc = comp_allclose_and_pcc(torch_input.grad, ttcpu_input_grad, pcc=0.999, rtol=rtol, atol=atol)
         logger.debug(f"input_grad passing={passing} pcc={output_pcc}")
         assert passing
@@ -172,7 +301,7 @@ def moreh_linear_backward(
         assert tt_input_grad is None
 
     if requires_weight_grad:
-        ttcpu_weight_grad = tt_weight_grad.cpu().to(cpu_layout).unpad_from_tile(weight_shape).to_torch()
+        ttcpu_weight_grad = ttnn.to_torch(tt_weight_grad).to(cpu_dtype)
         passing, output_pcc = comp_allclose_and_pcc(
             torch_weight.grad, ttcpu_weight_grad, pcc=0.999, rtol=rtol, atol=atol
         )
@@ -182,7 +311,7 @@ def moreh_linear_backward(
         assert tt_weight_grad is None
 
     if requires_bias_grad:
-        ttcpu_bias_grad = tt_bias_grad.cpu().to(cpu_layout).unpad_from_tile(bias_shape).to_torch()
+        ttcpu_bias_grad = ttnn.to_torch(tt_bias_grad).to(cpu_dtype)
         passing, output_pcc = comp_allclose_and_pcc(torch_bias.grad, ttcpu_bias_grad, pcc=0.999, rtol=rtol, atol=atol)
         logger.debug(f"bias_grad passing={passing} pcc={output_pcc}")
         assert passing
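Likewise for the backward path: a hedged sketch of the torch autograd reference that the elided portion of `moreh_linear_backward` presumably produces before the gradient comparisons above (names and shapes are illustrative):

    import torch

    x = torch.randn(1, 1, 32, 64, dtype=torch.bfloat16, requires_grad=True)
    W = torch.randn(32, 64, dtype=torch.bfloat16, requires_grad=True)
    b = torch.randn(32, dtype=torch.bfloat16, requires_grad=True)
    output_grad = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)

    out = torch.nn.functional.linear(x, W, b)
    out.backward(output_grad)  # populates x.grad, W.grad, b.grad for the PCC checks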
@@ -221,12 +350,15 @@ def moreh_linear_backward(
 )
 @pytest.mark.parametrize("requires_bias_grad", [True, False])
 @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
-def test_moreh_linear_backward(shapes, requires_grads, requires_bias_grad, compute_kernel_options, device):
+@pytest.mark.parametrize("npu_dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["BFP8", "BFP16"])
+def test_moreh_linear_backward(shapes, requires_grads, requires_bias_grad, compute_kernel_options, npu_dtype, device):
+    if npu_dtype == ttnn.bfloat8_b:
+        pytest.skip("Moreh Linear Backward does not support bfloat8_b")
     torch.manual_seed(3072)
     requires_input_grad, requires_weight_grad = requires_grads
     compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
     passing = moreh_linear_backward(
-        shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device
+        shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device, npu_dtype
     )
     assert passing
 
 
@@ -276,6 +408,9 @@ def test_moreh_bias_backward_fp32(shapes, device):
     compute_kernel_fp32_config = get_compute_kernel_options(True)
     compute_kernel_config = get_compute_kernel_options(False)
     requires_input_grad, requires_weight_grad, requires_bias_grad = (True, False, True)
+    npu_dtype = ttnn.bfloat16
+    cpu_dtype = torch.bfloat16
+
     input_shape, weight_shape, bias_shape, output_shape = shapes
     (
         tt_input,
@@ -288,13 +423,35 @@ def test_moreh_bias_backward_fp32(shapes, device):
         torch_weight,
         torch_output_grad,
     ) = get_tensors(
-        input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device, False
+        input_shape,
+        weight_shape,
+        output_shape,
+        requires_input_grad,
+        requires_weight_grad,
+        is_1d=False,
+        device=device,
+        ttnn_dtype=npu_dtype,
+        torch_dtype=cpu_dtype,
+        use_randint=False,
+    )
+    tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(
+        bias_shape, requires_bias_grad, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype, use_randint=False
     )
-    tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(bias_shape, requires_bias_grad, device, False)
     (_, _, _, _, tt_input_grad_fp32, _, _, _, _) = get_tensors(
-        input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device, False
+        input_shape,
+        weight_shape,
+        output_shape,
+        requires_input_grad,
+        requires_weight_grad,
+        is_1d=False,
+        device=device,
+        ttnn_dtype=npu_dtype,
+        torch_dtype=cpu_dtype,
+        use_randint=False,
+    )
+    (_, _, tt_bias_grad_fp32) = get_bias_tensors(
+        bias_shape, requires_bias_grad, device, torch_dtype=cpu_dtype, ttnn_dtype=npu_dtype, use_randint=False
     )
-    (_, _, tt_bias_grad_fp32) = get_bias_tensors(bias_shape, requires_bias_grad, device, False)
 
     ## tt linear backward (fp32 mode)
     tt_input_grad_fp32, _, tt_bias_grad_fp32 = ttnn.operations.moreh.linear_backward(
         tt_output_grad,
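Aside: the fp32 test above toggles accumulation precision through `get_compute_kernel_options(True)`. On Wormhole this presumably resolves to a config like the one in the helper removed from test_moreh_sgd.py below, sketched here for reference:

    import ttnn

    compute_kernel_fp32_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi4,
        math_approx_mode=False,
        fp32_dest_acc_en=True,  # accumulate in fp32 in the destination register
        packer_l1_acc=False,
    )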
diff --git a/tests/ttnn/unit_tests/operations/test_moreh_sgd.py b/tests/ttnn/unit_tests/operations/test_moreh_sgd.py
index 49bc7657080..841da226892 100644
--- a/tests/ttnn/unit_tests/operations/test_moreh_sgd.py
+++ b/tests/ttnn/unit_tests/operations/test_moreh_sgd.py
@@ -9,51 +9,22 @@
 import ttnn
 import pytest
 from models.utility_functions import comp_allclose_and_pcc, is_wormhole_b0
+from tests.ttnn.unit_tests.operations.test_utils import (
+    get_compute_kernel_options,
+    compute_kernel_options,
+    compute_kernel_ids,
+)
 from loguru import logger
 
 fp32_dest_acc_en = [
     False,  # for grayskull
 ]
 fp32_dest_acc_en_ids = ["fp32_dest_acc_en=False"]
-if is_wormhole_b0:
+if is_wormhole_b0():
     fp32_dest_acc_en.append(True)
     fp32_dest_acc_en_ids.append("fp32_dest_acc_en=True")
 
 
-def get_compute_kernel_options(fp32_dest_acc_en):
-    if fp32_dest_acc_en is None:
-        return None
-
-    if is_wormhole_b0():
-        packer_l1_acc = False
-        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
-            math_fidelity=ttnn.MathFidelity.HiFi4,
-            math_approx_mode=False,
-            fp32_dest_acc_en=fp32_dest_acc_en,
-            packer_l1_acc=packer_l1_acc,
-        )
-    else:
-        # Grayskull doesn't support fp32 but test passing a GS config is ok
-        compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
-            math_fidelity=ttnn.MathFidelity.HiFi4,
-            math_approx_mode=True,
-        )
-    return compute_kernel_config
-
-
-def create_tt_tensor(tensor, device):
-    ret = (
-        ttnn.Tensor(
-            tensor,
-            ttnn.bfloat16,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device)
-    )
-
-    return ret
-
-
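The removed `create_tt_tensor` helper is replaced throughout this diff by direct `ttnn.from_torch` calls, which also let each call site pick its dtype. A one-to-one sketch (assuming an open ttnn device):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    torch_tensor = torch.randn(32, 32, dtype=torch.bfloat16)

    # Old: ttnn.Tensor(torch_tensor, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device)
    dev_tensor = ttnn.from_torch(torch_tensor, ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    ttnn.close_device(device)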
 @pytest.mark.parametrize(
     "shape",
     [
@@ -71,6 +42,11 @@ def create_tt_tensor(tensor, device):
 )
 @pytest.mark.parametrize("has_param_out", [True, False], ids=["HAS_PARAM_OUT_TRUE", "HAS_PARAM_OUT_FALSE"])
 @pytest.mark.parametrize("fp32_dest_acc_en", fp32_dest_acc_en, ids=fp32_dest_acc_en_ids)
+@pytest.mark.parametrize(
+    "npu_dtype, cpu_dtype",
+    [[ttnn.bfloat8_b, torch.bfloat16], [ttnn.bfloat16, torch.bfloat16]],
+    ids=["bfloat8", "bfloat16"],
+)
 def test_moreh_sgd(
     shape,
     lr,
@@ -81,23 +57,29 @@ def test_moreh_sgd(
     momentum_initialized,
     has_param_out,
     fp32_dest_acc_en,
+    npu_dtype,
+    cpu_dtype,
     device,
 ):
     if nesterov and (momentum <= 0 or dampening != 0):
         pytest.skip()
 
+    if npu_dtype == ttnn.bfloat8_b:
+        # Duong: ttnn.bfloat8_b has some bugs; only around half of the tests pass for it,
+        # and some produce 0.0 or Inf results. I couldn't identify a pattern in the failing
+        # tests; it seems random, so I suspect a precision error.
+        pytest.skip()
+
     torch.manual_seed(0)
 
     compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)
 
     # make model and compute grad
-    x_data = torch.rand(shape).to(torch.bfloat16)
-    y_data = torch.rand(shape).to(torch.bfloat16)
+    x_data = torch.rand(shape).to(cpu_dtype)
+    y_data = torch.rand(shape).to(cpu_dtype)
 
     class SimpleModel(nn.Module):
         def __init__(self):
             super(SimpleModel, self).__init__()
-            self.weight = nn.Parameter(torch.randn(shape).to(torch.bfloat16)).to(torch.bfloat16)
+            self.weight = nn.Parameter(torch.randn(shape).to(cpu_dtype)).to(cpu_dtype)
 
         def forward(self, x):
             return torch.mul(x, self.weight)
@@ -121,7 +103,7 @@ def forward(self, x):
     cpu_momentum_out = None
     for i in range(0, step_cnt):
         cpu_param_in = model.weight.clone()
-        dev_param_in = create_tt_tensor(cpu_param_in, device)
+        dev_param_in = ttnn.from_torch(cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
         optimizer_state_dict = optimizer.state_dict()
         if momentum != 0:
@@ -136,21 +118,25 @@ def forward(self, x):
             cpu_momentum_out = optimizer_state_dict["state"][0]["momentum_buffer"].clone()
 
         # create other dev tensors
-        dev_param_out = create_tt_tensor(cpu_param_in, device)
-
+        dev_param_out = ttnn.from_torch(cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
         cpu_grad = model.weight.grad
-        dev_grad = create_tt_tensor(cpu_grad, device)
+
+        dev_grad = ttnn.from_torch(cpu_grad, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
         dev_momentum_buffer_in = None
         dev_momentum_buffer_out = None
         if momentum != 0:
             if momentum_initialized:
                 if cpu_momentum_in is not None:
-                    dev_momentum_buffer_in = create_tt_tensor(cpu_momentum_in, device)
+                    dev_momentum_buffer_in = ttnn.from_torch(
+                        cpu_momentum_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device
+                    )
                 else:
-                    dev_momentum_buffer_in = create_tt_tensor(cpu_param_in, device)
+                    dev_momentum_buffer_in = ttnn.from_torch(
+                        cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device
+                    )
 
-            dev_momentum_buffer_out = create_tt_tensor(cpu_param_in, device)
+            dev_momentum_buffer_out = ttnn.from_torch(cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
         dev_param_out, dev_momentum_buffer_out = ttnn.operations.moreh.sgd(
             dev_param_in,
@@ -170,7 +156,7 @@ def forward(self, x):
     assert dev_param_in.shape == list(model.weight.shape)
 
     # check param_out
-    param_result = dev_param_out.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16)
+    param_result = ttnn.to_torch(dev_param_out).to(cpu_dtype)
 
     rtol = atol = 0.05
     passing, out = comp_allclose_and_pcc(model.weight, param_result, pcc=0.99, rtol=rtol, atol=atol)
@@ -182,7 +168,7 @@ def forward(self, x):
 
     # check momentum_out
     if momentum != 0:
-        momentum_buffer_result = dev_momentum_buffer_out.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16)
+        momentum_buffer_result = ttnn.to_torch(dev_momentum_buffer_out).to(cpu_dtype)
 
         passing, out = comp_allclose_and_pcc(cpu_momentum_out, momentum_buffer_result, pcc=0.99, rtol=rtol, atol=atol)
         logger.debug(f"Momentum_out passing (param)={passing}")
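For reference, the update rule (standard torch.optim.SGD semantics, weight decay omitted for brevity) that both the CPU optimizer and `ttnn.operations.moreh.sgd` are expected to agree on in the checks above, as a plain-Python sketch:

    def sgd_step(param, grad, buf, lr, momentum, dampening, nesterov):
        """One SGD step; `buf` is the momentum buffer (None before the first step)."""
        if momentum != 0:
            if buf is None:
                buf = grad.clone()
            else:
                buf = momentum * buf + (1 - dampening) * grad
            d_p = grad + momentum * buf if nesterov else buf
        else:
            d_p = grad
        return param - lr * d_p, buf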
@@ -207,6 +193,7 @@ def forward(self, x):
 )
 @pytest.mark.parametrize("has_param_out", [True], ids=["HAS_PARAM_OUT_TRUE"])
 @pytest.mark.parametrize("fp32_dest_acc_en", fp32_dest_acc_en, ids=fp32_dest_acc_en_ids)
+@pytest.mark.parametrize("npu_dtype, cpu_dtype", [[ttnn.bfloat16, torch.bfloat16]], ids=["bfloat16"])
 def test_moreh_sgd_callback(
     shape,
     lr,
@@ -217,6 +204,8 @@ def test_moreh_sgd_callback(
     momentum_initialized,
     has_param_out,
     fp32_dest_acc_en,
+    npu_dtype,
+    cpu_dtype,
     device,
     use_program_cache,
 ):
@@ -224,17 +213,17 @@ def test_moreh_sgd_callback(
         pytest.skip()
 
     torch.manual_seed(0)
-
+    num_program_cache_entries_list = []
     compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)
 
     # make model and compute grad
-    x_data = torch.rand(shape).to(torch.bfloat16)
-    y_data = torch.rand(shape).to(torch.bfloat16)
+    x_data = torch.rand(shape).to(cpu_dtype)
+    y_data = torch.rand(shape).to(cpu_dtype)
 
     class SimpleModel(nn.Module):
         def __init__(self):
             super(SimpleModel, self).__init__()
-            self.weight = nn.Parameter(torch.randn(shape).to(torch.bfloat16)).to(torch.bfloat16)
+            self.weight = nn.Parameter(torch.randn(shape).to(cpu_dtype)).to(cpu_dtype)
 
         def forward(self, x):
             return torch.mul(x, self.weight)
@@ -275,21 +264,25 @@ def forward(self, x):
 
     # create other dev tensors
     for _ in range(2):
-        dev_param_in = create_tt_tensor(cpu_param_in, device)
-        dev_param_out = create_tt_tensor(cpu_param_in, device)
+        dev_param_in = ttnn.from_torch(cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
+        dev_param_out = ttnn.from_torch(cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
-        dev_grad = create_tt_tensor(cpu_grad, device)
+        dev_grad = ttnn.from_torch(cpu_grad, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
         dev_momentum_buffer_in = None
         dev_momentum_buffer_out = None
         if momentum != 0:
             if momentum_initialized:
                 if cpu_momentum_in is not None:
-                    dev_momentum_buffer_in = create_tt_tensor(cpu_momentum_in, device)
+                    dev_momentum_buffer_in = ttnn.from_torch(
+                        cpu_momentum_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device
+                    )
                 else:
-                    dev_momentum_buffer_in = create_tt_tensor(cpu_param_in, device)
+                    dev_momentum_buffer_in = ttnn.from_torch(
+                        cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device
+                    )
 
-            dev_momentum_buffer_out = create_tt_tensor(cpu_param_in, device)
+            dev_momentum_buffer_out = ttnn.from_torch(cpu_param_in, npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
         dev_param_out, dev_momentum_buffer_out = ttnn.operations.moreh.sgd(
             dev_param_in,
             dev_grad,
@@ -304,10 +297,13 @@ def forward(self, x):
             momentum_initialized=momentum_initialized,
             compute_kernel_config=compute_kernel_config,
         )
+        torch_dummy = torch.randn([32, 32])
+        tt_dummy = ttnn.from_torch(torch_dummy, device=device)
+        num_program_cache_entries_list.append(device.num_program_cache_entries())
 
         assert dev_param_in.shape == list(model.weight.shape)
 
         # check param_out
-        param_result = dev_param_out.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16)
+        param_result = ttnn.to_torch(dev_param_out).to(cpu_dtype)
 
         rtol = atol = 0.05
         passing, out = comp_allclose_and_pcc(model.weight, param_result, pcc=0.99, rtol=rtol, atol=atol)
@@ -318,10 +314,13 @@ def forward(self, x):
 
         # check momentum_out
         if momentum != 0:
-            momentum_buffer_result = dev_momentum_buffer_out.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16)
+            momentum_buffer_result = ttnn.to_torch(dev_momentum_buffer_out).to(cpu_dtype)
 
             passing, out = comp_allclose_and_pcc(cpu_momentum_out, momentum_buffer_result, pcc=0.99, rtol=rtol, atol=atol)
             logger.debug(f"Momentum_out passing (param)={passing}")
             logger.debug(f"Momentum_out pcc={out}")
 
             assert passing
+
+    logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
+    assert num_program_cache_entries_list[0] > 0
+    assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
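The callback test now follows the usual program-cache pattern: run the op twice, record the cache size after each run, and require that the second run reuses the first run's compiled program. The dummy `from_torch` allocation between runs shifts buffer addresses, so the cache hit on the second run is only correct if the program does not bake addresses in. Skeleton of the pattern, with a hypothetical `run_op` standing in for the `moreh.sgd` call:

    num_program_cache_entries_list = []
    for _ in range(2):
        run_op()  # hypothetical stand-in for the ttnn.operations.moreh.sgd call above
        tt_dummy = ttnn.from_torch(torch.randn([32, 32]), device=device)  # shift addresses
        num_program_cache_entries_list.append(device.num_program_cache_entries())

    assert num_program_cache_entries_list[0] > 0  # something was compiled and cached
    assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]  # cache hit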
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp
index d366c571117..6bcfa265181 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp
@@ -4,8 +4,8 @@
 
 #include "moreh_matmul_device_operation.hpp"
 
-#include "ttnn/operations/moreh/moreh_helper_functions.hpp"
 #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
+#include "ttnn/operations/moreh/moreh_helper_functions.hpp"
 #include "ttnn/tensor/tensor.hpp"
 
 namespace ttnn::operations::moreh::moreh_matmul {
@@ -23,10 +23,10 @@ void MorehMatmulOperation::validate_inputs(
     const auto &output = tensor_args.output;
 
     // validate tensor
-    tt::operations::primary::check_tensor(input, "moreh_matmul", "input");
-    tt::operations::primary::check_tensor(other, "moreh_matmul", "other");
-    tt::operations::primary::check_tensor(output, "moreh_matmul", "output");
-    tt::operations::primary::check_tensor(bias, "moreh_matmul", "bias");
+    tt::operations::primary::check_tensor(input, "moreh_matmul", "input", {DataType::BFLOAT16});
+    tt::operations::primary::check_tensor(other, "moreh_matmul", "other", {DataType::BFLOAT16});
+    tt::operations::primary::check_tensor(output, "moreh_matmul", "output", {DataType::BFLOAT16});
+    tt::operations::primary::check_tensor(bias, "moreh_matmul", "bias", {DataType::BFLOAT16});
 
     // check matrix dims
     const auto &input_shape = input.get_shape().value.without_padding();
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_device_operation.cpp
index 5ac5ced3af2..a9645109a22 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_device_operation.cpp
@@ -6,6 +6,7 @@
 
 #include "ttnn/operations/moreh/moreh_helper_functions.hpp"
 #include "ttnn/tensor/tensor.hpp"
+#include "ttnn/tensor/types.hpp"
 
 namespace ttnn::operations::moreh::moreh_sgd {
 void MorehSgdOperation::validate_inputs(
@@ -13,20 +14,22 @@ void MorehSgdOperation::validate_inputs(
     auto& params_in = tensor_args.param_in;
     auto& grad = tensor_args.grad;
 
-    tt::operations::primary::check_tensor(params_in, "moreh_sgd", "params_in");
-    tt::operations::primary::check_tensor(grad, "moreh_sgd", "grad");
+    tt::operations::primary::check_tensor(params_in, "moreh_sgd", "params_in", {DataType::BFLOAT16});
+    tt::operations::primary::check_tensor(grad, "moreh_sgd", "grad", {DataType::BFLOAT16});
 
     if (tensor_args.momentum_buffer_in) {
-        tt::operations::primary::check_tensor(*tensor_args.momentum_buffer_in, "moreh_sgd", "momentum_buffer_in");
+        tt::operations::primary::check_tensor(
+            *tensor_args.momentum_buffer_in, "moreh_sgd", "momentum_buffer_in", {DataType::BFLOAT16});
     }
 
     if (tensor_args.param_out.has_value()) {
-        tt::operations::primary::check_tensor(tensor_args.param_out.value(), "moreh_sgd", "param_out");
+        tt::operations::primary::check_tensor(
+            tensor_args.param_out.value(), "moreh_sgd", "param_out", {DataType::BFLOAT16});
    }
 
     if (tensor_args.momentum_buffer_out.has_value()) {
         tt::operations::primary::check_tensor(
-            tensor_args.momentum_buffer_out.value(), "moreh_sgd", "momentum_buffer_out");
+            tensor_args.momentum_buffer_out.value(), "moreh_sgd", "momentum_buffer_out", {DataType::BFLOAT16});
     }
 }
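With `check_tensor` now given an allow-list of `{DataType::BFLOAT16}`, non-bfloat16 inputs are rejected at validation time instead of running with unsupported precision. A sketch of the user-visible effect, assuming moreh linear routes through the matmul validation above and that the failed check surfaces as a RuntimeError in Python (the exact exception type is an assumption):

    import pytest
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    bf8_input = ttnn.from_torch(
        torch.randn(1, 1, 32, 64), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device
    )
    bf16_weight = ttnn.from_torch(
        torch.randn(1, 1, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
    )

    with pytest.raises(RuntimeError):  # bfloat8_b input should fail the dtype check
        ttnn.operations.moreh.linear(bf8_input, bf16_weight)

    ttnn.close_device(device)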