diff --git a/forge/forge/op/eval/forge/convolution.py b/forge/forge/op/eval/forge/convolution.py
index 984b4796..c6179bad 100644
--- a/forge/forge/op/eval/forge/convolution.py
+++ b/forge/forge/op/eval/forge/convolution.py
@@ -226,7 +226,6 @@ def eval(self, tensors):
         assert len(tensors) <= 3, "ConvTranspose ops should have up to three inputs (input, weight, bias)"
         assert len(tensors) >= 2, "ConvTranspose ops should have at least two inputs (input, weight)"
         t_ops = to_torch_operands(*tensors)
-
         activations = t_ops[0]
         weights = t_ops[1]
         bias = t_ops[2] if len(t_ops) == 3 else None
@@ -234,49 +233,40 @@ def eval(self, tensors):
         stride = [self.stride_height, self.stride_width]
         dilation = [self.dilation_height, self.dilation_width]
         groups = self.groups
-        padding = [
-            self.padding_left,
-            self.padding_right,
-            self.padding_top,
-            self.padding_bottom,
-        ]
-
+        assert self.padding_top == self.padding_bottom, "Padding values for top and bottom must be the same."
+        assert self.padding_left == self.padding_right, "Padding values for left and right must be the same."
+        padding = (self.padding_top, self.padding_left)
         channel_last = self.channel_last
         if channel_last:
             activations = activations.permute((0, 3, 1, 2))

-        padded_activations = torch.nn.functional.pad(
-            activations,
-            padding,
-        )
         if t_ops[1].dtype == torch.int8:
             target_dtype = torch.int32
-            padded_activations, weights = padded_activations.float(), weights.float()
+            activations, weights = activations.float(), weights.float()
             if bias is not None:
                 bias = bias.float()
         else:
             target_dtype = torch.float32

         result = torch.nn.functional.conv_transpose2d(
-            padded_activations,
+            activations,
             weights,
             bias=bias,
             stride=stride,
-            padding=0,
+            padding=padding,
             dilation=dilation,
             groups=groups,
         )

         if channel_last:
             result = result.permute((0, 2, 3, 1))
-
         result = result.to(target_dtype)
         return result

     def shape(self, tensor_shapes):
         act, weight = tensor_shapes[:2]
         batch_size = act[0]
-        cout = weight[1]
+        cout = weight[1] * self.groups

         h_in = act[-3] if self.channel_last else act[-2]
         w_in = act[-2] if self.channel_last else act[-1]
@@ -285,22 +275,20 @@ def shape(self, tensor_shapes):
         output_padding_width = 0

         h_out = (
-            ((h_in - 1) * self.stride_height)
-            - (2 * (self.padding_top + self.padding_bottom))
-            + (self.dilation_height * (weight[-2] - 1))
+            (h_in - 1) * self.stride_height
+            - (self.padding_top + self.padding_bottom)
+            + self.dilation_height * (weight[-2] - 1)
             + output_padding_height
             + 1
         )
         w_out = (
-            ((w_in - 1) * self.stride_width)
-            - (2 * (self.padding_left + self.padding_right))
-            + (self.dilation_width * (weight[-1] - 1))
+            (w_in - 1) * self.stride_width
+            - (self.padding_left + self.padding_right)
+            + self.dilation_width * (weight[-1] - 1)
             + output_padding_width
             + 1
         )
-
         out_shape = [batch_size, h_out, w_out, cout] if self.channel_last else [batch_size, cout, h_out, w_out]
-
         return out_shape, []

     def decompose(self, dc, inputs):
diff --git a/forge/forge/tvm_to_python.py b/forge/forge/tvm_to_python.py
index e5c28434..aa3c6b08 100644
--- a/forge/forge/tvm_to_python.py
+++ b/forge/forge/tvm_to_python.py
@@ -697,7 +697,14 @@ def populate_conv2d_transpose_args(graph, nid, compiler_cfg):
         )
     )

-    in_channel = next((n["attrs"]["shape"][0][0][0] for n in graph["nodes"] if n["name"] == "model.weight"), None)
+    in_channel = None
+    for input_ in node["inputs"]:
+        input_nid = input_[0]
+        input_node = graph["nodes"][input_nid]
+        if input_node["op"] == "parameter" and input_node["name"].endswith("weight"):
+            in_channel = input_node["attrs"]["shape"][0][0][0]
+            break
+
     groups = int(node["attrs"]["groups"][0][0])
     assert groups == 1 or (in_channel is not None and groups == in_channel), "Only supports group of 1 or in_channel"
     args.append(
diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py
index dbc9de64..1a1c4b27 100644
--- a/forge/test/mlir/test_ops.py
+++ b/forge/test/mlir/test_ops.py
@@ -1307,16 +1307,25 @@ def forward(self, a):

 @pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 @pytest.mark.parametrize(
-    "in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode",
+    "in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode, input_shape",
     [
-        (16, 33, (3, 3), 2, 0, 1, True, 1, "zeros"),
-        (16, 33, (3, 3), 2, 0, 1, False, 1, "zeros"),
-        (16, 33, (3, 5), 2, 0, 1, True, 1, "zeros"),
+        (16, 33, (3, 3), 2, 0, 1, True, 1, "zeros", (16, 50, 100)),
+        (16, 32, (3, 5), 2, 1, 1, True, 1, "zeros", (16, 50, 100)),
+        (16, 16, (3, 3), 1, 1, 16, True, 1, "zeros", (16, 50, 100)),
+        (16, 33, (3, 3), 2, 0, 1, True, 1, "zeros", (20, 16, 50, 100)),
+        (16, 33, (3, 3), 2, 0, 1, False, 1, "zeros", (20, 16, 50, 100)),
+        (16, 33, (3, 5), 2, 0, 1, True, 1, "zeros", (20, 16, 50, 100)),
+        (16, 16, (5, 5), 1, 2, 1, True, 1, "zeros", (20, 16, 50, 100)),
+        (16, 32, (3, 5), 2, 1, 1, True, 1, "zeros", (20, 16, 50, 100)),
+        (16, 32, (3, 3), 4, 1, 1, False, 1, "zeros", (20, 16, 50, 100)),
+        (16, 16, (3, 3), 2, 2, 1, True, 1, "zeros", (20, 16, 50, 100)),
+        (16, 16, (3, 3), 1, 1, 16, True, 1, "zeros", (20, 16, 50, 100)),
     ],
 )
-@pytest.mark.push
-def test_convtranspose2d(in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode):
-    inputs = [torch.randn(20, 16, 50, 100)]
+def test_convtranspose2d(
+    in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode, input_shape
+):
+    inputs = [torch.randn(*input_shape)]

     framework_model = torch.nn.ConvTranspose2d(
         in_channels=in_channels,
diff --git a/forge/test/model_demos/high_prio/cnn/pytorch/test_monodle.py b/forge/test/model_demos/high_prio/cnn/pytorch/test_monodle.py
index 77e2da30..756493fd 100644
--- a/forge/test/model_demos/high_prio/cnn/pytorch/test_monodle.py
+++ b/forge/test/model_demos/high_prio/cnn/pytorch/test_monodle.py
@@ -13,7 +13,7 @@ def test_monodle_pytorch(test_device):

     # PyBuda configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

     model_name = "monodle_pytorch"