diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py index a38aa60a770..5e464a75ff1 100644 --- a/models/demos/convnet_mnist/tt/convnet_mnist.py +++ b/models/demos/convnet_mnist/tt/convnet_mnist.py @@ -50,6 +50,8 @@ def convnet_mnist( conv_op_cache={}, debug=True, groups=1, + return_output_size=True, + return_prepared_device_weights=True, ) x = ttnn.relu(x) @@ -93,6 +95,8 @@ def convnet_mnist( conv_op_cache={}, debug=False, groups=1, + return_output_size=True, + return_prepared_device_weights=True, ) x = ttnn.relu(x) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py index a1c0496f211..11d2cbb7fbf 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py @@ -165,7 +165,7 @@ def run_downsample_if_req( shard_layout = ( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED ) - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -188,6 +188,7 @@ def run_downsample_if_req( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -230,12 +231,14 @@ def __call__( weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) act_block_h_override = 0 @@ -296,17 +299,19 @@ def __call__( deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, act_block_h_override=act_block_h_override, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -323,12 +328,13 @@ def __call__( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) if not 
self.run_downsample_before_conv2: @@ -545,6 +551,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 @@ -851,6 +859,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 3a5c75967e9..52b342f925a 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -160,7 +160,7 @@ def run_downsample_if_req( ): if self.downsample: logger.debug(f"Running downsample") - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -177,9 +177,11 @@ def run_downsample_if_req( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED + if height_sharding + else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), deallocate_activation=True, reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16), reshard_if_not_optimal=reshard_if_not_optimal, @@ -195,6 +197,7 @@ def run_downsample_if_req( enable_subblock_padding=enable_subblock_padding, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -244,14 +247,16 @@ def __call__( weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, packer_l1_accum_enabled=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) act_block_h_override = 0 @@ -328,9 +333,9 @@ def __call__( deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, act_block_h_override=act_block_h_override, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, packer_l1_accum_enabled=packer_l1_acc, @@ -340,6 +345,8 @@ def __call__( enable_subblock_padding=enable_subblock_padding, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) logger.debug( @@ -375,14 +382,16 @@ def __call__( dtype=self.model_config["ACTIVATIONS_DTYPE"], 
weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, packer_l1_accum_enabled=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) if not run_downsample_before_conv2: @@ -731,6 +740,8 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt input_width=self.conv1_input_width, conv_config=self.conv1_config, conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 if self.batch_size == 20: diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py index 5c0750003c1..b58405a2bb3 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py @@ -162,7 +162,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -179,14 +179,17 @@ def run_downsample_if_req( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED + if height_sharding + else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), deallocate_activation=True, reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -227,12 +230,14 @@ def __call__( weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) act_block_h_override = 0 @@ -291,17 +296,19 @@ def __call__( deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, act_block_h_override=act_block_h_override, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = 
ttnn.conv2d( + out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -318,12 +325,13 @@ def __call__( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) if not self.run_downsample_before_conv2: @@ -539,6 +547,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 @@ -842,6 +852,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py index b6643d55d4a..64090e65d2f 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py @@ -180,14 +180,18 @@ def run_downsample_if_req( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED + if height_sharding + else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), deallocate_activation=True, reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -228,12 +232,14 @@ def __call__( weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) act_block_h_override = 0 @@ -293,17 +299,19 @@ def __call__( deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, act_block_h_override=act_block_h_override, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # conv3 is 1x1 conv # print("Running 
conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -320,12 +328,13 @@ def __call__( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) if not self.run_downsample_before_conv2: @@ -541,6 +550,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 @@ -872,6 +883,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py index 45d93ebf685..967f079fef6 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py @@ -163,7 +163,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -180,15 +180,18 @@ def run_downsample_if_req( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED + if height_sharding + else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), deallocate_activation=True, reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -234,13 +237,15 @@ def __call__( weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) if is_wormhole_b0(): @@ -342,18 +347,20 @@ def __call__( deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, 
act_block_h_override=act_block_h_override, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -370,13 +377,14 @@ def __call__( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], math_fidelity=self.model_config["MATH_FIDELITY"], - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if height_sharding - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_prepared_device_weights=True, ) if not self.run_downsample_before_conv2: @@ -605,6 +613,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 @@ -938,6 +948,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # Relu is fused with conv1 diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py index 4cb986c2730..dd23c28eedb 100644 --- a/models/demos/vgg/tt/ttnn_vgg.py +++ b/models/demos/vgg/tt/ttnn_vgg.py @@ -127,6 +127,8 @@ def ttnn_vgg16( input_width=conv_ttnn_params[iter_conv_id][3], conv_config=conv_config, conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) tt_x = ttnn.from_device(tt_output_tensor_on_device) ttnn.deallocate(tt_output_tensor_on_device) @@ -249,6 +251,8 @@ def ttnn_vgg11( input_width=conv_ttnn_params_2[iter_conv_id][3], conv_config=conv_config, conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) tt_x = ttnn.from_device(tt_output_tensor_on_device) ttnn.deallocate(tt_output_tensor_on_device) diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py index a2700198f83..907c1ed3555 100644 --- a/models/demos/wormhole/mamba/tt/mamba_conv.py +++ b/models/demos/wormhole/mamba/tt/mamba_conv.py @@ -103,6 +103,8 @@ def __call__(self, input_tensor): conv_op_cache={}, debug=False, groups=self.config.groups // self.config.channels_split_factor, + return_output_length=True, + return_prepared_device_weights=True, ) self.tt_weight_tensor_splits[i] = weights_device output_tensor_splits.append(ttnn.sharded_to_interleaved(tt_output_tensor_on_device)) diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py index 3635026d809..1b7ce90028f 100644 --- 
a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py @@ -104,9 +104,11 @@ def __call__( math_approx_mode_enabled=True, fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if self.in_channels < 320 - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED + if self.in_channels < 320 + else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=True, @@ -129,6 +131,8 @@ def __call__( bias_tensor=self.conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting( # self.device, diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py index 4e63fc9b13c..caf68ec8bc7 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py @@ -432,6 +432,8 @@ def __call__( input_width=self.conv1_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) else: @@ -510,6 +512,8 @@ def __call__( input_width=self.conv1_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) if i != 0: split_hidden_states[i] = ttnn.add( @@ -630,6 +634,8 @@ def __call__( input_width=self.conv2_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut @@ -672,6 +678,8 @@ def __call__( input_width=self.conv_shortcut_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states): diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py index 12e4d543207..14f67b6fccf 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py @@ -264,6 +264,8 @@ def __call__( bias_tensor=self.proj_in_conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) inner_dim = hidden_states.shape[-1] @@ -312,6 +314,8 @@ def __call__( bias_tensor=self.proj_out_conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) if output_bfloat16: diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py index 9cbdfff2f48..b7884373a42 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py +++ 
b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py @@ -409,6 +409,8 @@ def __call__( input_width=self.input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) sample = ttnn.reallocate(sample) # TODO: Test remove @@ -672,6 +674,8 @@ def __call__( bias_tensor=self.conv_out_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) sample = ttnn.to_memory_config(sample, ttnn.L1_MEMORY_CONFIG) sample = ttnn.clone(sample, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16) diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py index 622a63065db..4348af74692 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py @@ -118,5 +118,7 @@ def __call__(self, input, in_channels, out_channels): bias_tensor=self.conv_bias_tensor, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_size=True, + return_prepared_device_weights=True, ) return tt_out diff --git a/models/demos/yolov4/ttnn/common.py b/models/demos/yolov4/ttnn/common.py index 7f7a98d75b5..edddb79f30b 100644 --- a/models/demos/yolov4/ttnn/common.py +++ b/models/demos/yolov4/ttnn/common.py @@ -90,7 +90,7 @@ def __call__(self, device, input_tensor): if self.act_block_h is not None: conv_config.act_block_h_override = self.act_block_h - [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d( + [output_tensor, self.weights, self.bias] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.weights, bias_tensor=self.bias, @@ -104,5 +104,6 @@ def __call__(self, device, input_tensor): input_height=self.input_params[1], input_width=self.input_params[2], conv_config=conv_config, + return_prepared_device_weights=True, ) return output_tensor diff --git a/models/experimental/functional_segformer/tt/common.py b/models/experimental/functional_segformer/tt/common.py index 85de8856df6..ff908280c42 100644 --- a/models/experimental/functional_segformer/tt/common.py +++ b/models/experimental/functional_segformer/tt/common.py @@ -72,6 +72,8 @@ def __call__(self, device, input_tensor): input_width=input_tensor.shape[2], conv_config=conv_config, groups=self.groups, + return_output_size=True, + return_prepared_device_weights=True, ) ## TODO: Op | WARNING | Tensor at index 0 is not allocated # print("sr2a", output_tensor.shape) diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 3d47538c4e5..6905d824d60 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -194,7 +194,7 @@ def __init__( self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper) def __call__(self, x): - x, _, _, self.weight, self.bias = ttnn.conv2d( + x, self.weight, self.bias = ttnn.conv2d( input_tensor=x, weight_tensor=self.weight, bias_tensor=self.bias, @@ -209,6 +209,7 @@ def __call__(self, x): padding=self.padding, conv_config=self.conv_config, conv_op_cache=self.cache, + return_prepared_device_weights=True, ) return x diff --git a/tests/sweep_framework/sweep_utils/conv2d_common.py 
b/tests/sweep_framework/sweep_utils/conv2d_common.py index 55769adb984..30b06b9739f 100644 --- a/tests/sweep_framework/sweep_utils/conv2d_common.py +++ b/tests/sweep_framework/sweep_utils/conv2d_common.py @@ -153,6 +153,8 @@ def run_full( input_width=input_width, conv_config=conv_config, groups=groups, + return_output_size=True, + return_prepared_device_weights=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -235,6 +237,8 @@ def run_short( input_height=input_height, input_width=input_width, groups=groups, + return_output_size=True, + return_prepared_device_weights=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index b7341753cc4..9340fdca09b 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -625,7 +625,9 @@ def conv( t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) # t2 = ttnn.experimental.tensor.conv(t0, t1, conv_params, 0, 0, 0, 0, 0, conv_params[0]) - t2 = ttnn.conv2d(t0, t1, conv_params, 0, 0, 0, 0, 0, conv_params[0]) + t2 = ttnn.conv2d( + t0, t1, conv_params, 0, 0, 0, 0, 0, conv_params[0], return_output_size=True, return_prepared_device_weights=True + ) return tt2torch_tensor(t2) diff --git a/tests/ttnn/unit_tests/operations/test_conv1d.py b/tests/ttnn/unit_tests/operations/test_conv1d.py index 3e7a1496c63..e169404e78d 100644 --- a/tests/ttnn/unit_tests/operations/test_conv1d.py +++ b/tests/ttnn/unit_tests/operations/test_conv1d.py @@ -120,6 +120,8 @@ def run_conv( conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, + return_output_length=True, + return_prepared_device_weights=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 1a62fd50b50..0e08ee95e02 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -168,6 +168,8 @@ def run_conv( debug=debug, groups=groups, memory_config=memory_config, + return_output_size=True, + return_prepared_device_weights=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -306,6 +308,8 @@ def run_conv_with_split( input_width=input_width, conv_config=conv_config, conv_op_cache=reader_patterns_cache, + return_output_size=True, + return_prepared_device_weights=True, ) tt_conv_output_tensor = ttnn.from_device(tt_output_tensor_on_device) torch_conv_output_tensor = ttnn.to_torch(tt_conv_output_tensor) @@ -555,6 +559,8 @@ def test_conv_ws( conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, + return_output_size=True, + return_prepared_device_weights=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py b/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py new file mode 100644 index 00000000000..06e115f5783 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +from models.utility_functions import ( + is_wormhole_b0, + skip_for_grayskull, + is_grayskull, + is_wormhole_b0, + is_x2_harvested, + is_blackhole, + skip_for_blackhole, + is_blackhole, +) +from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc, check_with_pcc_without_tensor_printout +import ttnn + + +@skip_for_grayskull() +@skip_for_blackhole() +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + ( + # unique convs in rn50 (complete list) + # first conv post folding and input_channels padding to tile width + # (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, None), HANGS!! + (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 256}), + # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 32}), Out of Memory!! + # rn50 layer1 + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + # rn50 layer2 + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), + (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), + (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, {"act_block_h": 32}), + (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + # rn50 layer3 + (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), + (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), + (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), + (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), + (16, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), + (20, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), + # rn50 layer4 + (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), + (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), + (20, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), + (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + (16, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + (20, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + ## small test + (1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 2, "grid_size": (2, 2)}), + (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 4, "grid_size": (2, 4)}), + # (1, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, False, None), sliding_window_op_infra/sliding_window.cpp:341: indices_length_last_core <= indices_length_per_core + (8, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + # r50 1x1s2 shapes + # Fails with packer_l1_acc = True (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 first bottleneck downsample shape + (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, True, None), # r50 first bottleneck downsample shape + # Fails with packer_l1_acc = True (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 second bottleneck downsample shape + # (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit + (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, False, None), # r50 third bottleneck downsample shape + # (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit + (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, False, None), # r50 fourth bottleneck downsample shape + # (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit + # (20, 128, 256, 56, 56, 1, 1, 2, 2, 0, 0, True, None), ## L2M1 DS: doesn't fit + ), +) 
+@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) +@pytest.mark.parametrize("has_bias", [True, False], ids=["has_bias", "no_bias"]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 2**15}], indirect=True) +def test_prepare_conv_weights( + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + packer_l1_acc, + config_override, + has_bias, + device, +): + if device.core_grid.y == 7: + pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") + + inp_shape = (batch_size, input_channels, input_height, input_width) + conv_weight_shape = (output_channels, input_channels, filter_height, filter_width) + torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16) + torch_input_tensor = torch.randn(inp_shape, dtype=torch.bfloat16) + torch_bias_tensor = torch.randn((1, 1, 1, output_channels), dtype=torch.bfloat16) if has_bias else None + + torch_out_golden_tensor = torch.nn.functional.conv2d( + torch_input_tensor, + torch_weight_tensor, + bias=torch_bias_tensor.reshape(-1) if has_bias else None, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(1, 1), + groups=1, + ).permute(0, 2, 3, 1) + + tt_input_tensor = ttnn.from_torch(torch_input_tensor.transpose(-3, -2).transpose(-2, -1), ttnn.bfloat16) + tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16) + tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16) if has_bias else None + + conv_config = ttnn.Conv2dConfig( + dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat16, + input_channels_alignment=(16 if input_channels == 16 and input_height == 115 else 32), + packer_l1_accum_enabled=packer_l1_acc, + enable_act_double_buffer=False, + enable_split_reader=False, + enable_subblock_padding=False, + ) + + if config_override and "act_block_h" in config_override: + conv_config.act_block_h_override = config_override["act_block_h"] + + if config_override and "act_block_w_div" in config_override: + conv_config.act_block_w_div = config_override["act_block_w_div"] + + if config_override and "num_cores_nhw" in config_override: + if config_override["num_cores_nhw"] == 98: + conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (11, 7)), ttnn.CoreRange((0, 8), (1, 8))}) + conv_config.override_sharding_config = True + print("Setting num_cores_nhw to 98") + + conv_kwargs = { + "in_channels": input_channels, + "out_channels": output_channels, + "batch_size": batch_size, + "input_height": input_height, + "input_width": input_width, + "kernel_size": (filter_height, filter_width), + "stride": (stride_h, stride_w), + "padding": (pad_h, pad_w), + "dilation": (1, 1), + "groups": 1, + "device": device, + "conv_config": conv_config, + } + + tt_input_tensor = ttnn.to_device(tt_input_tensor, device) + tt_weight_tensor_formatted = ttnn.prepare_conv_weights( + weight_tensor=tt_weight_tensor, + weights_format="OIHW", + input_memory_config=tt_input_tensor.memory_config(), + **conv_kwargs, + ) + tt_bias_tensor_formatted = ( + ttnn.prepare_conv_bias( + bias_tensor=tt_bias_tensor, input_memory_config=tt_input_tensor.memory_config(), **conv_kwargs + ) + if has_bias + else None + ) + + tt_weight_tensor_formatted = ttnn.to_device(tt_weight_tensor_formatted, device) + tt_bias_tensor_formatted = ttnn.to_device(tt_bias_tensor_formatted, device) if has_bias else None + + 
tt_output_tensor_on_device = ttnn.conv2d( + input_tensor=tt_input_tensor, + weight_tensor=tt_weight_tensor_formatted, + bias_tensor=tt_bias_tensor_formatted, + **conv_kwargs, + ) + + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + torch_output_tensor = ttnn.to_torch(tt_output_tensor) + + torch_output_tensor = torch_output_tensor[:, :, :, :output_channels] + torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape) + # + pcc = 0.99 + passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc) + logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}") + assert passing diff --git a/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py b/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py index 84ee4d5d972..d80d2d42a2c 100644 --- a/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py +++ b/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py @@ -83,15 +83,15 @@ def __init__(self, parameters, downsample, model_config) -> None: if downsample: self.ds_conv_weight_tensor = ttnn.from_torch( parameters.ds_conv.weight, - dtype=model_config["WEIGHTS_DTYPE"] - if model_config["WEIGHTS_DTYPE"] != ttnn.bfloat8_b - else ttnn.float32, + dtype=( + model_config["WEIGHTS_DTYPE"] if model_config["WEIGHTS_DTYPE"] != ttnn.bfloat8_b else ttnn.float32 + ), ) self.ds_conv_bias_tensor = ttnn.from_torch( parameters.ds_conv.bias, - dtype=model_config["WEIGHTS_DTYPE"] - if model_config["WEIGHTS_DTYPE"] != ttnn.bfloat8_b - else ttnn.float32, + dtype=( + model_config["WEIGHTS_DTYPE"] if model_config["WEIGHTS_DTYPE"] != ttnn.bfloat8_b else ttnn.float32 + ), ) self.ds_conv_input_channels = self.ds_conv_weight_tensor.shape[1] self.ds_conv_output_channels = self.ds_conv_weight_tensor.shape[0] @@ -121,6 +121,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( @@ -143,6 +145,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac activation="relu", ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) if self.downsample: @@ -165,6 +169,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) ttnn.deallocate(x) else: @@ -191,6 +197,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac activation="relu", ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # conv3 is 1x1 conv @@ -214,6 +222,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_size=True, + return_prepared_device_weights=True, ) # underscore version is in_place = True diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index fd207c920a0..d0d8ad59589 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -38,12 +38,14 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/barrier/barrier.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/barrier/barrier_pybind.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/core/core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index c5685ac57b4..12cf212dc85 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -3,6 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 #include "conv2d.hpp" +#include "conv2d_utils.hpp" +#include "prepare_conv2d_weights.hpp" #include #include #include @@ -26,765 +28,9 @@ using sliding_window::ParallelConfig; namespace conv2d { -uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { - uint32_t divisor = start_divisor; - while (num % divisor != 0) divisor = divisor - 1; - return divisor; -} - -uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { - uint32_t divisor = start_divisor; - uint32_t padded_num = round_up(num, divisor); - while ((padded_num - num) >= (int)(padded_num / divisor)) { - divisor = divisor - 1; - padded_num = round_up(num, divisor); - } - return divisor; -} - -uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { - uint32_t divisor = start_divisor; - while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; - return divisor; -} - -// Converts convolution weights to tilized 2d matrix layout. 
-// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - Tensor conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype){ - return tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype); - } - -// Converts convolution weights to tilized 2d matrix layout with special block height padding -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - Tensor conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype){ - return tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype); - } - -// Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype){ - return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(conv_weight_tensor, num_groups, output_dtype); -} - -ParallelConfig determine_parallel_config( - const TensorMemoryLayout shard_layout, - uint32_t batch_size, - uint32_t input_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t output_channels, - const CoreCoord& compute_grid_size, - ShardOrientation block_shard_orientation, - bool is_out_tiled) { - - uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1; - uint32_t effective_tile_width = is_out_tiled ? tt::constants::TILE_WIDTH : 1; - uint32_t out_nhw_ntiles = tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height; - uint32_t out_c_ntiles = tt::round_up(output_channels, effective_tile_width) / effective_tile_width; - - // calculate num_core_nhw and the grid - uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; - uint32_t num_cores_nhw = 0; - CoreRangeSet grid; - if (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - num_cores_nhw = find_closest_largest_divisor(out_nhw_ntiles, max_num_cores); - if (num_cores_nhw < compute_grid_size.x && out_nhw_ntiles > compute_grid_size.x) { - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, compute_grid_size.x); - } - grid = num_cores_to_corerangeset(num_cores_nhw, compute_grid_size, true); - } else if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { - uint32_t start_divisor = - block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); - uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x); - uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c; - uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? 
num_cores_c : num_cores_nhw; - CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); - grid = CoreRangeSet({core_range}); - } else if (shard_layout == TensorMemoryLayout::WIDTH_SHARDED) { - num_cores_nhw = 1; - uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), max_num_cores); - grid = num_cores_to_corerangeset(num_cores_c, compute_grid_size, true); - } else { - TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout); - } - - auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED ? block_shard_orientation : ShardOrientation::ROW_MAJOR; // NOTE: taking ROW_MAJOR as default orientation for HEIGHT_SHARDED and WIDTH_SHARDED - ParallelConfig pconfig = { - .grid = grid, - .shard_scheme = shard_layout, - .shard_orientation = shard_orientation }; - - return pconfig; -} - -uint32_t get_num_cores_nhw_from_parallel_config(const ParallelConfig& pconfig) { - TT_ASSERT(!pconfig.grid.ranges().empty()); - TT_ASSERT( - pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED || - pconfig.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED || - pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED); - auto grid_size = pconfig.grid.bounding_box().grid_size(); - uint32_t num_cores = pconfig.grid.num_cores(); - uint32_t num_cores_nhw = 0; - if(pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { - return 1; - } - - if (pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { - num_cores_nhw = num_cores; - } else if (pconfig.shard_orientation == ShardOrientation::COL_MAJOR) { - num_cores_nhw = grid_size.x; - } else { - TT_ASSERT(pconfig.shard_orientation == ShardOrientation::ROW_MAJOR); - num_cores_nhw = grid_size.y; - } - - TT_ASSERT(num_cores_nhw > 0); - return num_cores_nhw; -} - -uint32_t get_num_cores_channels_from_parallel_config(const ParallelConfig& pconfig) { - TT_ASSERT(!pconfig.grid.ranges().empty()); - TT_ASSERT( - pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED || - pconfig.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED || - pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED); - auto grid_size = pconfig.grid.bounding_box().grid_size(); - uint32_t num_cores_channels = 0; - if (pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { - num_cores_channels = 1; - } else if(pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { - num_cores_channels = pconfig.grid.num_cores(); - } else if (pconfig.shard_orientation == ShardOrientation::COL_MAJOR) { - num_cores_channels = grid_size.y; - } else { - TT_ASSERT(pconfig.shard_orientation == ShardOrientation::ROW_MAJOR); - num_cores_channels = grid_size.x; - } - TT_ASSERT(num_cores_channels > 0); - return num_cores_channels; -} - -MemoryConfig create_sharded_memory_config_from_parallel_config( - const ttnn::Shape& tensor_shape, ParallelConfig& parallel_config, uint32_t tile_size) { - - log_debug(tt::LogOp, "create_sharded_memory_config_from_parallel_config: tensor_shape: {}, parallel_config: {}, tile_size: {}", tensor_shape, parallel_config, tile_size); - // tensor_shape is [N, H, W, C] - TT_ASSERT(tensor_shape[0] == 1 && tensor_shape[1] == 1); // todo: add support for generic non-2d shapes - // uint32_t channels = tensor_shape[3]; - uint32_t channels = tensor_shape.with_tile_padding()[3]; - uint32_t num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); - uint32_t num_cores_channels = 
get_num_cores_channels_from_parallel_config(parallel_config); - auto shard_scheme = parallel_config.shard_scheme; - auto shard_orientation = parallel_config.shard_orientation; - - uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2]; - uint32_t nhw_padded = nhw_shape; - if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { - nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); - } - uint32_t nhw_shard = nhw_padded / num_cores_nhw; - TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels); - uint32_t channel_shard = channels / num_cores_channels; - auto shard_spec = ShardSpec{parallel_config.grid, {nhw_shard, channel_shard}, shard_orientation}; - return MemoryConfig{shard_scheme, BufferType::L1, shard_spec}; -} - - -OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( - const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c) { - TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); - const auto& shard_spec = conv_output_mem_config.shard_spec.value(); - const auto& shard_shape = shard_spec.shape; - TT_ASSERT(shard_shape[1] % 32 == 0); - uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32); - return { - .grid_size = shard_spec.grid.bounding_box().grid_size(), - .num_cores_nhw = num_cores_nhw, - .num_cores_c = num_cores_c, - .per_core_out_matrix_height_ntiles = per_core_out_matrix_height_ntiles, - .per_core_out_matrix_width_ntiles = shard_shape[1] / 32, - .per_core_out_matrix_height = shard_shape[0], - .per_core_out_matrix_width = shard_shape[1], - }; -} - -std::pair determine_largest_subblock_size( - uint32_t block_height, uint32_t block_width, bool fp32_accum, bool split_reader_enabled) { - constexpr std::array, 20> subblocks = {{ - {2, 4}, {4, 2}, {1, 8}, {8, 1}, {1, 7}, {7, 1}, {2, 3}, {3, 2}, {1, 6}, {6, 1}, - {1, 5}, {5, 1}, {2, 2}, {1, 4}, {4, 1}, {1, 3}, {3, 1}, {1, 2}, {2, 1}, {1, 1}, - }}; - - uint32_t subblock_h = 0; - uint32_t subblock_w = 0; - for (auto [subblock_height, subblock_width] : subblocks) { - if (fp32_accum && (subblock_height * subblock_width > 4)) { - continue; - } - - if (split_reader_enabled && (block_height / subblock_height) < 2) { - continue; - } - - if ((block_height % subblock_height == 0) && (block_width % subblock_width == 0)) { - if (subblock_width != block_width && subblock_height != 1) { - continue; - } - subblock_h = subblock_height; - subblock_w = subblock_width; - break; - } - } - TT_ASSERT(subblock_h > 0 && subblock_w > 0); - return {subblock_h, subblock_w}; -} - -OptimizedConvBlockConfig determine_per_core_conv_block_config( - const ParallelConfig& parallel_config, - const OptimizedConvParallelizationConfig& conv_op_parallel_config, - uint32_t padded_in_channels, - uint32_t act_block_h_override, - uint32_t act_block_w_div, - uint32_t window_h, - uint32_t window_w, - bool fp32_accum, - bool split_reader_enabled) { - - if (act_block_h_override > 0) { - TT_ASSERT( - act_block_h_override % 32 == 0, - "Config Error: act_block_h_override must be a multiple of 32 (tile height)."); - } - auto grid_size = parallel_config.grid.bounding_box().grid_size(); - uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; - if (parallel_config.shard_scheme != TensorMemoryLayout::WIDTH_SHARDED && act_block_h_override > 0 ) { - log_debug(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used"); - act_block_h_ntiles = 
act_block_h_override / constants::TILE_HEIGHT; - } - uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED - ? round_up(padded_in_channels * window_w, 32) - : padded_in_channels; - if(parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { - act_block_w = (padded_in_channels * window_h * window_w)/(parallel_config.grid.num_cores() * act_block_w_div); - } - TT_ASSERT(act_block_w % 32 == 0); - uint32_t act_block_w_ntiles = act_block_w / 32; - uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1 - : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? grid_size.y - : grid_size.x; - uint32_t out_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; - uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntiles; - //act_block_h_ntiles / block_config.out_subblock_h_ntiles) >= 2 - auto [out_subblock_h_ntiles, out_subblock_w_ntiles] = - determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled); - return { - .act_block_h_ntiles = act_block_h_ntiles, - .act_block_w_ntiles = act_block_w_ntiles, - .out_subblock_h_ntiles = out_subblock_h_ntiles, - .out_subblock_w_ntiles = out_subblock_w_ntiles}; -} - -bool use_matmul_for_1x1_conv( - const std::array& kernel_size, - const std::array& stride, - const std::array& padding, - const std::array& dilation, - uint32_t groups) { - return kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 && padding[0] == 0 && - padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1; -} - -// Implements a heuristic for selecting shard layout based on how many tenix cores are available -// for each shard. -static TensorMemoryLayout select_shard_spec( - bool is_mm_conv, - uint32_t batch_size, - uint32_t in_channels, - uint32_t out_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t weights_width, - uint32_t input_width, - ShardOrientation shard_orientation, - const CoreCoord& compute_grid_size) { - auto get_core_count_for_sharding = [&](TensorMemoryLayout shard_layout) { - return determine_parallel_config( - shard_layout, - batch_size, - in_channels, - output_height, - output_width, - out_channels, - compute_grid_size, - shard_orientation) - .grid.num_cores(); - }; - - // 1d convs support only height sharding - const bool is_conv1d = weights_width == 1 && input_width == 1; - - const uint32_t cc_height = get_core_count_for_sharding(TensorMemoryLayout::HEIGHT_SHARDED); - // matmul doesn't support width sharding - const uint32_t cc_width = - !is_mm_conv && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::WIDTH_SHARDED) : 0; - const uint32_t cc_block = !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::BLOCK_SHARDED) : 0; - - uint32_t max_cc = cc_block; - TensorMemoryLayout shard_layout = TensorMemoryLayout::BLOCK_SHARDED; - - // Prefer block sharding over height sharding but make sure that we got at least - // some blocking on width dimension as well. - if (cc_height > max_cc || (cc_height == max_cc && cc_height <= compute_grid_size.x)) { - shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; - max_cc = cc_height; - } - - if (cc_width >= max_cc) { - shard_layout = TensorMemoryLayout::WIDTH_SHARDED; - max_cc = cc_width; - } - - if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { - // For large number of input channels prefer width sharding - // even if it has less cores. 
- // For BH we probably need to adjust this, or even better we make block sharding - // more configurable rearding l1 memory usage for weights. - if (cc_width >= 40 && in_channels > 1280) { - shard_layout = TensorMemoryLayout::WIDTH_SHARDED; - log_debug(LogOp, "Switching to WIDTH_SHARDED layout due to large in_channels"); - max_cc = cc_width; - } - } - log_debug(LogOp, "Selected shard layout: {}, num cores: {}", shard_layout, max_cc); - - return shard_layout; -} - -template -std::tuple get_conv_padded_input_shape_and_mem_config( - T* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) { - ttnn::Tensor input_tensor = input_tensor_; // tensor to return - bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); - bool needs_shard_or_reshard = false; - if (conv_config.override_sharding_config && conv_config.reshard_if_not_optimal) { - TT_ASSERT( - false, - "Incorrect config provided: reshard_if_not_optimal and override_sharding_config cannot both be set."); - } - - TT_FATAL( - (!input_tensor_on_device || input_tensor_.is_sharded()) || conv_config.shard_layout.has_value(), - "Tesor must be sharded or shard_layout must be set."); - - TensorMemoryLayout shard_layout; - if (conv_config.shard_layout.has_value()) { - shard_layout = conv_config.shard_layout.value(); - } - - ParallelConfig input_tensor_parallel_config; - if (!input_tensor_on_device) { - needs_shard_or_reshard = true; - } else { - const auto& input_memory_config = input_tensor_.memory_config(); - if (!input_memory_config.is_sharded()) { - needs_shard_or_reshard = true; - } else { - const auto input_shard_scheme = input_memory_config.memory_layout; - const auto input_shard_orientation = input_memory_config.shard_spec.value().orientation; - const auto input_shard_grid = input_memory_config.shard_spec.value().grid; - ParallelConfig pconfig = { - .grid = input_shard_grid, - .shard_scheme = input_shard_scheme, - .shard_orientation = input_shard_orientation}; - input_tensor_parallel_config = pconfig; - if (input_shard_scheme != TensorMemoryLayout::BLOCK_SHARDED && - input_shard_orientation != ShardOrientation::ROW_MAJOR) { - needs_shard_or_reshard = true; - } - if (input_shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED && - input_shard_scheme != TensorMemoryLayout::BLOCK_SHARDED && - input_shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { - needs_shard_or_reshard = true; - } - if (conv_config.override_sharding_config) { - TT_FATAL(conv_config.core_grid.has_value(), "If override_sharding_config is set, core_grid must be set as well."); - TT_FATAL(conv_config.shard_layout.has_value(), "If override_sharding_config is set, shard_layout must be set as well."); - if (conv_config.core_grid.value() != input_shard_grid) { - needs_shard_or_reshard = true; - } - if(shard_layout!=input_shard_scheme) { - needs_shard_or_reshard = true; - } - bool input_transpose_shards = input_shard_orientation == ShardOrientation::COL_MAJOR; - if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED && conv_config.transpose_shards != input_transpose_shards) { - needs_shard_or_reshard = true; - } - } - } - } - - // shallow conv variriant not supported - // out_channels <= 256 incorrect output from pack_untilize_dst if output > 256 Tracking --> #14236 - // bf8 not supported due to limation of sharding dim multipl of 32 - bool use_non_tile_height = (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) && 
out_channels <= 256 && conv_config.act_block_h_override == 0 && - (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR && conv_config.input_channels_alignment != 16; //shalow conv varient - - ParallelConfig parallel_config = input_tensor_parallel_config; - if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { - auto block_shard_orientation = - conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - ParallelConfig optimal_parallel_config = determine_parallel_config( - shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height); - - if (conv_config.override_sharding_config) { - TT_FATAL(conv_config.core_grid.has_value(), "Error"); - // override parallel config - auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED - ? block_shard_orientation - : ShardOrientation::ROW_MAJOR; - parallel_config = { - .grid = conv_config.core_grid.value(), - .shard_scheme = shard_layout, - .shard_orientation = shard_orientation}; - } else { - parallel_config = optimal_parallel_config; - } - if (input_tensor_parallel_config != parallel_config) { - needs_shard_or_reshard = true; - } - } - if (needs_shard_or_reshard) { - uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); - // TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape()); - uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2]; - uint32_t round_up_size = (use_non_tile_height || conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED) ? 1 : tt::constants::TILE_HEIGHT; - uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); - TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); - uint32_t tensor_width = input_tensor.get_shape()[3]; - uint32_t input_tensor_width_snapped_to_channels_alignment = - tt::round_up(tensor_width, conv_config.input_channels_alignment); - TT_ASSERT(input_tensor_width_snapped_to_channels_alignment >= tensor_width); - - auto input_padded_shape = ttnn::Shape(std::array{ - 1, - 1, - input_tensor_height_snapped_to_tile, - input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and - // tt::tt_metal::LegacyShape issue to clean up next line - auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( - ttnn::Shape(std::array{ - input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), - parallel_config, round_up_size); - return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height}; - } else { - return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard, use_non_tile_height}; - } -} - -template -std::tuple shard_or_reshard_tensor_if_required( - T* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels, - bool is_mm_conv) { - ttnn::Tensor input_tensor = input_tensor_; // tensor to return - bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); - - auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height] = - 
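The padded shape computed here is driven by two round-ups: the flattened N*H*W extent is snapped to num_cores_nhw * TILE_HEIGHT (or just to num_cores_nhw on the non-tile-height / width-sharded path) and the channel dim is snapped to input_channels_alignment. A worked example with made-up dimensions:

// Illustrative only: a [2, 28, 28, 3] activation height-sharded over 8 cores with the default
// input_channels_alignment of 32.
uint32_t nhw        = 2 * 28 * 28;                  // 1568 flattened rows
uint32_t nhw_padded = tt::round_up(nhw, 8u * 32u);  // 1792, a multiple of 8 cores x 32-row tiles (224 rows per core)
uint32_t c_padded   = tt::round_up(3u, 32u);        // 32
// -> padded input shape [1, 1, 1792, 32], from which the sharded MemoryConfig is derived.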
get_conv_padded_input_shape_and_mem_config( - device, - input_tensor_, - conv_config, - batch_size, - height, - width, - in_channels, - out_channels); - ParallelConfig parallel_config = { - .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, - .shard_scheme = input_tensor_sharded_memory_config.memory_layout, - .shard_orientation = input_tensor_sharded_memory_config.shard_spec.value().orientation - }; - if (needs_shard_or_reshard) { - if (input_tensor.get_shape()[0] != 1 or input_tensor.get_shape()[1] != 1) { - // reshape to [1, 1, N*H*W, C] - input_tensor = ttnn::reshape( - input_tensor, - ttnn::SimpleShape(std::array{ - 1, - 1, - input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2], - input_tensor.get_shape()[3]})); - } - - uint32_t tensor_height = input_tensor.get_shape()[2]; - uint32_t tensor_width = input_tensor.get_shape()[3]; - - if (!input_tensor_on_device) { - if (input_padded_shape[-2] != tensor_height || input_padded_shape[-1] != tensor_width) { - input_tensor = ttnn::pad( - input_tensor, - tt::tt_metal::Array4D({input_tensor.get_shape()[0], - input_tensor.get_shape()[1], - input_padded_shape[-2], - input_padded_shape[-1]}), - tt::tt_metal::Array4D({0, 0, 0, 0}), - 0); - } - } - - if (input_tensor_on_device) { - if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && - parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { - // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout - input_tensor = - ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); - } - auto resharded_input_tensor = ttnn::to_memory_config( - input_tensor, input_tensor_sharded_memory_config, std::nullopt); - if (conv_config.deallocate_activation) { - input_tensor.deallocate(); - resharded_input_tensor = ttnn::operations::core::reallocate(resharded_input_tensor, resharded_input_tensor.memory_config()); - } - input_tensor = resharded_input_tensor; - } else { - if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && - parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { - // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout - input_tensor = ttnn::to_device(input_tensor, device, std::nullopt); - input_tensor = - ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); - input_tensor = ttnn::to_memory_config(input_tensor, input_tensor_sharded_memory_config, std::nullopt); - } else { - input_tensor = ttnn::to_device(input_tensor, device, input_tensor_sharded_memory_config); - } - } - } - return {input_tensor, parallel_config, needs_shard_or_reshard, use_non_tile_height}; -} - -void validate_weight_and_bias_tensors( - const ttnn::Tensor& weight_tensor, std::optional& bias_tensor) { - TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE)); - TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR); - TT_ASSERT(weight_tensor.get_shape().rank() == 4); - // TODO: enable this assert - // TT_ASSERT(weight_tensor.get_shape() == weight_tensor.get_legacy_shape()); - if (bias_tensor.has_value()) { - TT_ASSERT(!ttnn::has_storage_type_of(bias_tensor.value(), ttnn::DEVICE_STORAGE_TYPE)); - TT_ASSERT(bias_tensor.value().get_shape().rank() == 4); - TT_ASSERT(bias_tensor.value().get_layout() == Layout::ROW_MAJOR); - // TODO: enable this assert - // TT_ASSERT(bias_tensor.value().get_shape() == bias_tensor.value().get_legacy_shape()); - } -} - -template -std::pair> 
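shard_or_reshard_tensor_if_required always works on the flattened [1, 1, N*H*W, C] view: host tensors are padded and then pushed to the device with the sharded memory config, while device tensors are resharded in place (optionally deallocating the original). A minimal sketch of the flattening step, with hypothetical dimensions:

// Illustrative only: a [2, 28, 28, 32] activation is folded to [1, 1, 1568, 32] before any
// padding or shard spec is applied; only the last two dims matter for sharding.
input_tensor = ttnn::reshape(
    input_tensor,
    ttnn::SimpleShape(std::array<uint32_t, 4>{1, 1, 2 * 28 * 28, 32}));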
prepare_conv_weights_biases_and_move_to_device( - const ttnn::Tensor& weight_tensor, - std::optional& bias_tensor, - uint32_t input_channels_alignment, - DataType weights_bias_dtype, - uint32_t weight_block_h_ntiles, - uint32_t weight_block_w_ntiles, - const ParallelConfig& parallel_config, - T * device, - uint32_t groups, - uint32_t act_block_h_ntiles, - uint32_t input_width) { - - validate_weight_and_bias_tensors(weight_tensor, bias_tensor); - ttnn::Tensor weight_tensor_; // tensor to return - ttnn::Tensor bias_tensor_; - - auto original_weights_shape = weight_tensor.get_shape(); - uint32_t original_weights_out_channels = original_weights_shape[0]; - uint32_t original_weights_in_channels = original_weights_shape[1]; - uint32_t original_weights_window_h = original_weights_shape[2]; - uint32_t original_weights_window_w = original_weights_shape[3]; - - bool is_conv1d = original_weights_window_w == 1 && input_width == 1; - bool is_depthwise_conv = groups == original_weights_out_channels && original_weights_in_channels == 1; - - weight_tensor_ = weight_tensor; - - // Convert weight tensor to 0 padded shape if groups > 1 - if (!is_conv1d and groups > 1) { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); - } - else if (is_conv1d and groups > 1) { - if (is_depthwise_conv) { - weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype); - weight_block_h_ntiles = act_block_h_ntiles; - } - else{ - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); - } - } - - auto weights_shape = weight_tensor_.get_shape(); - uint32_t out_channels = weights_shape[0]; - uint32_t in_channels = weights_shape[1]; - uint32_t window_h = weights_shape[2]; - uint32_t window_w = weights_shape[3]; - uint32_t out_channel_padding = tt::round_up(out_channels, 32) - out_channels; - tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( - {tt::round_up(out_channels, 32), tt::round_up(in_channels, input_channels_alignment), window_h, window_w})); - if (weights_bias_dtype == DataType::BFLOAT8_B) { - TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32); - if (bias_tensor.has_value()) { - TT_ASSERT(bias_tensor.value().get_dtype() == DataType::FLOAT32); - } - } else { - // TODO: fix the need to check this. 
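Before any tiling, the weight preparation above zero-pads the OIHW tensor so that out_channels lands on a tile boundary and in_channels lands on input_channels_alignment; grouped and depthwise weights are first rewritten into the padded layouts the conv kernels expect. With made-up numbers:

// Illustrative only: an OIHW weight of shape [30, 3, 3, 3] with the default
// input_channels_alignment of 32.
uint32_t out_padded = tt::round_up(30u, 32u);  // 32
uint32_t in_padded  = tt::round_up(3u, 32u);   // 32
// -> weights_channels_padded_shape == {32, 32, 3, 3}; the two extra output channels are
//    tracked as out_channel_padding and stripped from the logical shape later.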
We should be able to accept any datatype and convert - TT_ASSERT(weight_tensor_.get_dtype() == weights_bias_dtype); - if (bias_tensor.has_value()) { - TT_ASSERT(bias_tensor.value().get_dtype() == weights_bias_dtype); - } - } - weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); - - // for conv op, pad the weights to block shape - if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( - weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); - } else { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout( - weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); - } - - uint32_t weight_matrix_height = in_channels * window_h * window_w; - int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height; - TT_FATAL(weight_matrix_height_padding >= 0," Matrix Height Padding can't be negative"); - - // convert_conv_weight_tensor adds the padding to the base shape. - // Reshape the weights to remove padding from the base shape. - weight_tensor_.set_shape( - ttnn::Shape(std::array{1, 1, weight_matrix_height, out_channels}, - std::array, 4>{ - std::array{0, 0}, - std::array{0, 0}, - std::array{0, weight_matrix_height_padding}, - std::array{0, out_channel_padding} - })); - - weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt); - if (bias_tensor.has_value()) { - bias_tensor_ = bias_tensor.value(); - auto bias_shape = bias_tensor_.get_shape(); - TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1); - tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( - std::array({1, 1, 32, tt::round_up(out_channels, weight_block_w_ntiles * 32)})); - bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); - bias_tensor_ = ttnn::to_layout( - bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); - if (bias_tensor_.get_dtype() != weights_bias_dtype) { - bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype); - } - bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); - } - - return {weight_tensor_, bias_tensor.has_value() ? 
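After tiling, the weight tensor is logically a [1, 1, in_channels * kh * kw, out_channels] matrix, and whatever extra rows the tiled-layout conversion added are re-declared as padding rather than data; the bias is padded to one tile row whose width is rounded to the weight block width. For example (hypothetical sizes):

// Illustrative only: 64 input channels, 3x3 kernel, 64 output channels, weight_block_w_ntiles = 2.
uint32_t weight_matrix_height = 64 * 3 * 3;                   // 576 rows = 18 tile rows
uint32_t bias_padded_width    = tt::round_up(64u, 2u * 32u);  // 64
// -> bias is padded to shape [1, 1, 32, 64] and tilized before being moved to device.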
bias_tensor_ : std::optional()}; -} - -ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config( - OptimizedConvParallelizationConfig conv_parallelization_config, - OptimizedConvBlockConfig conv_blocking_config, - bool height_sharded, - string activation, - bool transpose_mcast, - uint32_t grid_size_along_c) { - if (height_sharded) { - ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig matmul_config = { - .compute_with_storage_grid_size = conv_parallelization_config.grid_size, - .in0_block_w = conv_blocking_config.act_block_w_ntiles, - .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, - .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, - .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntiles, - .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntiles, - .fuse_batch = true, - .mcast_in0 = false}; - if (activation != "") { - matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation); - } - return matmul_config; - } else { - TT_ASSERT(conv_blocking_config.act_block_w_ntiles % grid_size_along_c == 0); - ttnn::operations::matmul::MatmulMultiCoreReuseMultiCastProgramConfig matmul_config = { - .compute_with_storage_grid_size = conv_parallelization_config.grid_size, - .in0_block_w = conv_blocking_config.act_block_w_ntiles / grid_size_along_c, - .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, - .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, - .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntiles, - .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntiles, - .transpose_mcast = transpose_mcast}; - if (activation != "") { - matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation); - } - return matmul_config; - } -} - -static void adjust_conv_op_config_for_auto_shard( - bool is_mm_conv, - uint32_t batch_size, - uint32_t in_channels, - uint32_t out_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t weights_width, - uint32_t input_width, - const CoreCoord& compute_grid_size, - Conv2dConfig& conv_config) { - ShardOrientation shard_orientation = - conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - conv_config.shard_layout = select_shard_spec( - is_mm_conv, - batch_size, - in_channels, - out_channels, - output_height, - output_width, - weights_width, - input_width, - shard_orientation, - compute_grid_size); - - if (conv_config.act_block_h_override == 0 && conv_config.shard_layout != TensorMemoryLayout::WIDTH_SHARDED) { - if (in_channels <= constants::TILE_WIDTH / 2 && conv_config.input_channels_alignment == constants::TILE_WIDTH && - !is_mm_conv && conv_config.shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - log_debug(LogOp, "Auto shard, enable shallow conv"); - // height sharded, non matmul conv, with input channels <= 16, and default setting for - // input_channels_alignment - conv_config.input_channels_alignment = constants::TILE_WIDTH / 2; - } - - // Set act_block_h_override to min value to - // be conservative with L1 memory usage. 
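determine_matmul_op_config_from_conv_op_config is a straight field remapping from the conv block/parallelization configs onto a matmul program config: 1D multicast for height sharding, 2D multicast otherwise. A hedged sketch of the height-sharded case, reusing the configs conv2d computes earlier:

// Illustrative only: with act_block_w_ntiles = 9 and a per-core output of 8 x 2 tiles, the
// height-sharded branch yields in0_block_w = 9, per_core_M = 8, per_core_N = 2,
// fuse_batch = true, mcast_in0 = false; the block-sharded branch divides in0_block_w by the
// number of cores along the channel dim, so it must divide evenly.
auto program_config = determine_matmul_op_config_from_conv_op_config(
    opt_conv_op_parallel_config, opt_conv_op_block_config,
    /*height_sharded=*/true, /*activation=*/"relu",
    /*transpose_mcast=*/false, /*grid_size_along_c=*/1);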
- conv_config.act_block_h_override = constants::TILE_HEIGHT; - } -} +using OutputHeight = uint32_t; +using OutputWidth = uint32_t; +using Result = std::tuple>; template Result conv2d( @@ -810,20 +56,21 @@ Result conv2d( ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { - // In this case we deduce the shard layout. - adjust_conv_op_config_for_auto_shard( - mm_conv, - batch_size, - in_channels, - out_channels, - output_height, - output_width, - weight_tensor.get_shape()[3], - input_width, - device->compute_with_storage_grid_size(), - conv_config); - } + // In this case we deduce the shard layout. + adjust_conv_op_config_for_auto_shard_if_necessary( + mm_conv, + batch_size, + in_channels, + out_channels, + output_height, + output_width, + weight_tensor.get_shape()[3], + input_width, + kernel_size, + stride, + device->compute_with_storage_grid_size(), + conv_config, + ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) : std::nullopt); auto [input_tensor_post_tm, parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv); @@ -858,18 +105,45 @@ Result conv2d( std::optional bias_tensor_on_device = bias_tensor; if (!weight_is_on_device) { // prepare weights in desired layout and move to device - tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device( + weight_tensor_on_device = prepare_conv_weights( weight_tensor, - bias_tensor, - conv_config.input_channels_alignment, - conv_config.weights_dtype, - opt_conv_op_block_config.act_block_w_ntiles, - opt_conv_op_block_config.out_subblock_w_ntiles, - parallel_config, - device, + input_tensor_post_tm.memory_config(), + "OIHW", + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, groups, - opt_conv_op_block_config.act_block_h_ntiles, - input_width); + device, + conv_config + ); + weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt); + + } + + if (bias_tensor.has_value() && !ttnn::is_tensor_on_device_or_multidevice(bias_tensor.value())) { + bias_tensor_on_device = prepare_conv_bias( + bias_tensor.value(), + input_tensor_post_tm.memory_config(), + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + device, + conv_config + ); + bias_tensor_on_device = ttnn::operations::core::to_device(bias_tensor_on_device.value(), device, std::nullopt); } // if 1x1 conv w/ stride 1, convert input tensor to tile layout if required Tensor input_tensor_post_tm_out; @@ -1028,75 +302,51 @@ Result conv2d( } } -template std::tuple get_conv_padded_input_shape_and_mem_config( - Device* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - -template std::tuple get_conv_padded_input_shape_and_mem_config( - MeshDevice * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t 
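With this change the weight/bias transformation is exposed through prepare_conv_weights and prepare_conv_bias instead of the fused prepare_conv_weights_biases_and_move_to_device helper, so a caller can pre-process once and keep device-resident tensors across conv2d invocations. A hedged sketch of that flow (argument names follow the calls inside conv2d above; error handling omitted):

// Illustrative only: prepare weights/bias on host in the layout conv2d expects, then push them
// to the device once and reuse them.
ttnn::Tensor prepared_w = prepare_conv_weights(
    weight_tensor, input_tensor.memory_config(), "OIHW",
    in_channels, out_channels, batch_size, input_height, input_width,
    kernel_size, stride, padding, dilation, groups, device, conv_config);
prepared_w = ttnn::operations::core::to_device(prepared_w, device, std::nullopt);
ttnn::Tensor prepared_b = prepare_conv_bias(
    bias_tensor.value(), input_tensor.memory_config(),
    in_channels, out_channels, batch_size, input_height, input_width,
    kernel_size, stride, padding, dilation, groups, device, conv_config);
prepared_b = ttnn::operations::core::to_device(prepared_b, device, std::nullopt);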
out_channels); - -template std::tuple shard_or_reshard_tensor_if_required( - Device* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels, - bool is_mm_conv); - -template std::tuple shard_or_reshard_tensor_if_required( - MeshDevice * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channel, - bool is_mm_conv); - -template std::pair> prepare_conv_weights_biases_and_move_to_device( - const ttnn::Tensor& weight_tensor, - std::optional& bias_tensor, - uint32_t input_channels_alignment, - DataType weights_bias_dtype, - uint32_t weight_block_h_ntiles, - uint32_t weight_block_w_ntiles, - const ParallelConfig& parallel_config, - Device * device, - uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width); - -template std::pair> prepare_conv_weights_biases_and_move_to_device( - const ttnn::Tensor& weight_tensor, - std::optional& bias_tensor, - uint32_t input_channels_alignment, - DataType weights_bias_dtype, - uint32_t weight_block_h_ntiles, - uint32_t weight_block_w_ntiles, - const ParallelConfig& parallel_config, - MeshDevice * device, - uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width); +// template Result conv2d( +// const ttnn::Tensor& input_tensor, +// const ttnn::Tensor& weight_tensor, +// Device * device, +// uint32_t in_channels, +// uint32_t out_channels, +// uint32_t batch_size, +// uint32_t input_height, +// uint32_t input_width, +// std::array kernel_size, +// std::array stride, +// std::array padding, +// std::array dilation, +// uint32_t groups, +// std::optional bias_tensor, +// std::optional conv_config_, +// const std::optional memory_config){ +// return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_, memory_config); +// } + +// template Result conv2d( +// const ttnn::Tensor& input_tensor, +// const ttnn::Tensor& weight_tensor, +// MeshDevice * device, +// uint32_t in_channels, +// uint32_t out_channels, +// uint32_t batch_size, +// uint32_t input_height, +// uint32_t input_width, +// std::array kernel_size, +// std::array stride, +// std::array padding, +// std::array dilation, +// uint32_t groups, +// std::optional bias_tensor, +// std::optional conv_config_, +// const std::optional memory_config){ +// return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_, memory_config); +// } Result Conv2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - Device * device, + MeshDevice * device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -1117,7 +367,7 @@ Result Conv2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - MeshDevice * device, + Device * device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index 47c6cea3df2..4722cdbc86b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -6,6 +6,7 @@ #include 
#include +#include "conv2d_utils.hpp" #include "ttnn/core.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/operations/matmul/matmul.hpp" @@ -29,173 +30,6 @@ using OutputHeight = uint32_t; using OutputWidth = uint32_t; using Result = std::tuple>; -struct Conv2dConfig { - MathFidelity math_fidelity = MathFidelity::HiFi4; - DataType dtype = DataType::BFLOAT16; - DataType weights_dtype = DataType::BFLOAT16; - bool math_approx_mode_enabled = true; - bool fp32_dest_acc_enabled = false; - bool packer_l1_accum_enabled = false; - string activation = ""; - uint32_t input_channels_alignment = 32; - bool deallocate_activation = false; - bool reallocate_halo_output = false; - uint32_t act_block_h_override = 0; // This argument is ignored when shard_layout == WIDTH_SHARDED. - uint32_t act_block_w_div = 1; //Amount by which the maximum possible act_block_width is divided. Max act_block_w = (in_channels * window_w * window_h)/total_num_cores; - //Ignored when shard_layout == HEIGHT_SHARDED or BLOCK_SHARDED - bool reshard_if_not_optimal = false; // if true, override_sharding_config should not be set to true - bool override_sharding_config = false; // if true, reshard_if_not_optimal should not be set to true - std::optional shard_layout; - std::optional core_grid = std::nullopt; // used only if override_sharding_config is true - bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false - Layout output_layout = Layout::TILE; - bool enable_act_double_buffer = false; - bool enable_weights_double_buffer = false; // Used on for block sharded convolutions - bool enable_split_reader = false; - bool enable_subblock_padding = false; - static constexpr auto attribute_names = std::make_tuple( - "math_fidelity", - "dtype", - "weights_dtype", - "math_approx_mode_enabled", - "fp32_dest_acc_enabled", - "packer_l1_accum_enabled", - "activation", - "input_channels_alignment", - "deallocate_activation", - "reallocate_halo_output", - "act_block_h_override", - "act_block_w_div", - "reshard_if_not_optimal", - "override_sharding_config", - "shard_layout", - "core_grid", - "transpose_shards", - "output_layout", - "enable_act_double_buffer", - "enable_weights_double_buffer", - "enable_split_reader", - "enable_subblock_padding"); - const auto attribute_values() const { - return std::make_tuple( - std::cref(this->math_fidelity), - std::cref(this->dtype), - std::cref(this->weights_dtype), - std::cref(this->math_approx_mode_enabled), - std::cref(this->fp32_dest_acc_enabled), - std::cref(this->packer_l1_accum_enabled), - std::cref(this->activation), - std::cref(this->input_channels_alignment), - std::cref(this->deallocate_activation), - std::cref(this->reallocate_halo_output), - std::cref(this->act_block_h_override), - std::cref(this->act_block_w_div), - std::cref(this->reshard_if_not_optimal), - std::cref(this->override_sharding_config), - std::cref(this->shard_layout), - std::cref(this->core_grid), - std::cref(this->transpose_shards), - std::cref(this->output_layout), - std::cref(this->enable_act_double_buffer), - std::cref(this->enable_weights_double_buffer), - std::cref(this->enable_split_reader), - std::cref(this->enable_subblock_padding)); - } -}; - -uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor); - -uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor); - -uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor); - -bool use_matmul_for_1x1_conv( - const 
std::array& kernel_size, - const std::array& stride, - const std::array& padding, - const std::array& dilation, - uint32_t groups); - -sliding_window::ParallelConfig determine_parallel_config( - const TensorMemoryLayout shard_layout, - uint32_t batch_size, - uint32_t input_channels, - uint32_t output_height, - uint32_t output_width, - uint32_t output_channels, - const CoreCoord& compute_grid_size, - ShardOrientation block_shard_orientation, - bool is_out_tiled=true); - -uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig); - -uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig); - -MemoryConfig create_sharded_memory_config_from_parallel_config(const ttnn::Shape& tensor_shape, sliding_window::ParallelConfig& parallel_config, uint32_t tile_size); - -OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( - const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); - -std::pair determine_largest_subblock_size(uint32_t block_height, uint32_t block_width, bool fp32_accum); - -OptimizedConvBlockConfig determine_per_core_conv_block_config( - const sliding_window::ParallelConfig& parallel_config, - const OptimizedConvParallelizationConfig& conv_op_parallel_config, - uint32_t padded_in_channels, - uint32_t act_block_h_override, - uint32_t act_block_w_div, - uint32_t window_h, - uint32_t window_w, - bool fp32_accum, - bool split_reader_enabled); - -template -std::tuple get_conv_padded_input_shape_and_mem_config( - T * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - -template -std::tuple shard_or_reshard_tensor_if_required( - T* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels, - bool is_mm_conv); - -void validate_weight_and_bias_tensors(const ttnn::Tensor& weight_tensor, std::optional& bias_tensor); - -// Converts convolution weights to tilized 2d matrix layout. 
-// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - Tensor conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype = std::nullopt); - -// Converts convolution weights to tilized 2d matrix layout with special block height padding -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - Tensor conv_weight_tensor, - uint32_t in1_block_h, - uint32_t in1_block_w, - std::optional output_dtype = std::nullopt); - -// Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype); - -template -std::pair> prepare_conv_weights_biases_and_move_to_device(const ttnn::Tensor& weight_tensor, std::optional& bias_tensor, uint32_t input_channels_alignment, DataType weights_bias_dtype, uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const sliding_window::ParallelConfig& parallel_config, T * device, uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width); - template Result conv2d( const ttnn::Tensor& input_tensor, @@ -215,7 +49,6 @@ Result conv2d( std::optional conv_config_ = std::nullopt, const std::optional memory_config = std::nullopt); - struct Conv2dOperation{ static Result invoke( uint8_t queue_id, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index d00c50b46d0..b7fcddfb07b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -9,6 +9,9 @@ #include "conv2d_pybind.hpp" #include "ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.hpp" #include "conv2d.hpp" +#include "conv2d_utils.hpp" +#include "prepare_conv2d_weights.hpp" +#include "ttnn/types.hpp" namespace py = pybind11; @@ -120,6 +123,85 @@ void py_bind_conv2d(py::module& module) { py::arg("queue_id") = 0} ); + module.def( + "prepare_conv_weights", + prepare_conv_weights, + py::kw_only(), + py::arg("weight_tensor"), + py::arg("input_memory_config"), + py::arg("weights_format"), + py::arg("in_channels"), + py::arg("out_channels"), + py::arg("batch_size"), + py::arg("input_height"), + py::arg("input_width"), + py::arg("kernel_size"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::arg("groups"), + py::arg("device"), + py::arg("conv_config") = std::nullopt); + + + module.def( + "prepare_conv_weights", + prepare_conv_weights, + py::kw_only(), + py::arg("weight_tensor"), + py::arg("input_memory_config"), + py::arg("weights_format"), + py::arg("in_channels"), + py::arg("out_channels"), + py::arg("batch_size"), + py::arg("input_height"), + py::arg("input_width"), + py::arg("kernel_size"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::arg("groups"), + py::arg("device"), + py::arg("conv_config") = std::nullopt); + + module.def( + "prepare_conv_bias", + prepare_conv_bias, + py::kw_only(), + py::arg("bias_tensor"), + py::arg("input_memory_config"), + py::arg("in_channels"), + py::arg("out_channels"), + py::arg("batch_size"), + py::arg("input_height"), + py::arg("input_width"), + py::arg("kernel_size"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::arg("groups"), + py::arg("device"), + py::arg("conv_config") = std::nullopt); + + module.def( + "prepare_conv_bias", + prepare_conv_bias, + py::kw_only(), + py::arg("bias_tensor"), + 
py::arg("input_memory_config"), + py::arg("in_channels"), + py::arg("out_channels"), + py::arg("batch_size"), + py::arg("input_height"), + py::arg("input_width"), + py::arg("kernel_size"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::arg("groups"), + py::arg("device"), + py::arg("conv_config") = std::nullopt); + module.def( "get_conv_padded_input_shape_and_mem_config", [](ttnn::Device* device, @@ -130,7 +212,7 @@ void py_bind_conv2d(py::module& module) { uint32_t width, uint32_t in_channels, uint32_t out_channels) -> std::tuple { - return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( + return get_conv_padded_input_shape_and_mem_config( device, input_tensor, conv_config, @@ -160,7 +242,7 @@ void py_bind_conv2d(py::module& module) { uint32_t width, uint32_t in_channels, uint32_t out_channels) -> std::tuple { - return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( + return get_conv_padded_input_shape_and_mem_config( device, input_tensor, conv_config, @@ -182,7 +264,7 @@ void py_bind_conv2d(py::module& module) { module.def( "convert_conv_weight_tensor_to_tiled_layout", - &ttnn::operations::conv::conv2d::convert_conv_weight_tensor_to_tiled_layout, + &convert_conv_weight_tensor_to_tiled_layout, py::arg("conv_weight_tensor").noconvert(), py::arg("in1_block_h"), py::arg("in1_block_w"), @@ -190,7 +272,7 @@ void py_bind_conv2d(py::module& module) { module.def( "convert_conv_weight_tensor_to_special_padding_tiled_layout", - &ttnn::operations::conv::conv2d::convert_conv_weight_tensor_to_special_padding_tiled_layout, + &convert_conv_weight_tensor_to_special_padding_tiled_layout, py::arg("conv_weight_tensor").noconvert(), py::arg("in1_block_h"), py::arg("in1_block_w"), @@ -198,7 +280,7 @@ void py_bind_conv2d(py::module& module) { module.def( "convert_conv_weight_tensor_to_grouped_layout", - &ttnn::operations::conv::conv2d::convert_conv_weight_tensor_to_grouped_layout, + &convert_conv_weight_tensor_to_grouped_layout, py::arg("conv_weight_tensor").noconvert(), py::arg("num_groups"), py::arg("output_dtype").noconvert() = std::nullopt); @@ -214,7 +296,7 @@ void py_bind_conv2d(py::module& module) { const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { - return ttnn::operations::conv::conv2d::determine_parallel_config( + return determine_parallel_config( shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_out_tiled); }, py::arg("shard_layout"), @@ -229,7 +311,7 @@ void py_bind_conv2d(py::module& module) { module.def( "create_sharded_memory_config_from_parallel_config", - &ttnn::operations::conv::conv2d::create_sharded_memory_config_from_parallel_config, + &create_sharded_memory_config_from_parallel_config, py::arg("tensor_shape"), py::arg("parallel_config"), py::arg("tile_size")); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp new file mode 100644 index 00000000000..d3cf9c3eb73 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -0,0 +1,743 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include + +#include "conv2d_utils.hpp" +#include "common/constants.hpp" +#include "impl/buffers/buffer_constants.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/pool/downsample/device/downsample_op.hpp" +#include "tt_metal/detail/reports/memory_reporter.hpp" +#include "tt_metal/common/work_split.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "tt_metal/common/core_coord.hpp" + +using namespace tt; +namespace ttnn { +namespace operations::conv { +using sliding_window::SlidingWindowConfig; +using sliding_window::ParallelConfig; + + +uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { + uint32_t divisor = start_divisor; + while (num % divisor != 0) divisor = divisor - 1; + return divisor; +} + +uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { + uint32_t divisor = start_divisor; + uint32_t padded_num = round_up(num, divisor); + while ((padded_num - num) >= (int)(padded_num / divisor)) { + divisor = divisor - 1; + padded_num = round_up(num, divisor); + } + return divisor; +} + +uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { + uint32_t divisor = start_divisor; + while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + return divisor; +} + +// Converts convolution weights to tilized 2d matrix layout. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout( + Tensor conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype){ + return tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype); + } + +// Converts convolution weights to tilized 2d matrix layout with special block height padding +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( + Tensor conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype){ + return tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype); + } + +// Converts convolution weights to grouped layout with padded zeros +Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype){ + return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(conv_weight_tensor, num_groups, output_dtype); +} + +ParallelConfig determine_parallel_config( + const TensorMemoryLayout shard_layout, + uint32_t batch_size, + uint32_t input_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t output_channels, + const CoreCoord& compute_grid_size, + ShardOrientation block_shard_orientation, + bool is_out_tiled) { + + uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1; + uint32_t effective_tile_width = is_out_tiled ? 
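The two divisor helpers encode the core-count search: the first insists on an exact split, the second tolerates padding as long as the padding stays below one shard. A quick worked example with hypothetical tile counts:

// Illustrative only: 97 output tile-rows on an 8x8 grid.
uint32_t exact  = find_closest_largest_divisor(97, 64);                  // 1 (97 is prime)
uint32_t padded = find_closest_largest_divisor_with_num_padding(97, 8);  // 8 -> 13 rows/core, 7 padded rows
// determine_parallel_config uses exactly this fallback for height sharding when the exact
// divisor collapses below one grid row.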
tt::constants::TILE_WIDTH : 1; + uint32_t out_nhw_ntiles = tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height; + uint32_t out_c_ntiles = tt::round_up(output_channels, effective_tile_width) / effective_tile_width; + + // calculate num_core_nhw and the grid + uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; + uint32_t num_cores_nhw = 0; + CoreRangeSet grid; + if (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + num_cores_nhw = find_closest_largest_divisor(out_nhw_ntiles, max_num_cores); + if (num_cores_nhw < compute_grid_size.x && out_nhw_ntiles > compute_grid_size.x) { + num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, compute_grid_size.x); + } + grid = num_cores_to_corerangeset(num_cores_nhw, compute_grid_size, true); + } else if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { + uint32_t start_divisor = + block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; + num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); + uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x); + uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c; + uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw; + CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); + grid = CoreRangeSet({core_range}); + } else if (shard_layout == TensorMemoryLayout::WIDTH_SHARDED) { + num_cores_nhw = 1; + uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), max_num_cores); + grid = num_cores_to_corerangeset(num_cores_c, compute_grid_size, true); + } else { + TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout); + } + + auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED ? 
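For block sharding the grid is chosen per-axis: the NHW extent picks one axis via the padded-divisor search and the channel dim picks the other via the largest divisor common to the output-tile and input-tile counts. A worked example with hypothetical layer sizes:

// Illustrative only: batch 8, 56x56 output, 128 -> 256 channels, 8x8 grid, ROW_MAJOR shards.
ParallelConfig pc = determine_parallel_config(
    TensorMemoryLayout::BLOCK_SHARDED, /*batch_size=*/8, /*input_channels=*/128,
    /*output_height=*/56, /*output_width=*/56, /*output_channels=*/256,
    CoreCoord{8, 8}, ShardOrientation::ROW_MAJOR);
// out_nhw_ntiles = 8*56*56/32 = 784 splits evenly over 8 rows, and the channel axis settles on
// 4 cores (the largest divisor of both 8 output tiles and 4 input tiles), giving a 4x8 core
// range with a 98x2-tile output shard per core.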
block_shard_orientation : ShardOrientation::ROW_MAJOR; // NOTE: taking ROW_MAJOR as default orientation for HEIGHT_SHARDED and WIDTH_SHARDED + ParallelConfig pconfig = { + .grid = grid, + .shard_scheme = shard_layout, + .shard_orientation = shard_orientation }; + + return pconfig; +} + +uint32_t get_num_cores_nhw_from_parallel_config(const ParallelConfig& pconfig) { + TT_ASSERT(!pconfig.grid.ranges().empty()); + TT_ASSERT( + pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED || + pconfig.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED || + pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED); + auto grid_size = pconfig.grid.bounding_box().grid_size(); + uint32_t num_cores = pconfig.grid.num_cores(); + uint32_t num_cores_nhw = 0; + if(pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { + return 1; + } + + if (pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { + num_cores_nhw = num_cores; + } else if (pconfig.shard_orientation == ShardOrientation::COL_MAJOR) { + num_cores_nhw = grid_size.x; + } else { + TT_ASSERT(pconfig.shard_orientation == ShardOrientation::ROW_MAJOR); + num_cores_nhw = grid_size.y; + } + + TT_ASSERT(num_cores_nhw > 0); + return num_cores_nhw; +} + +uint32_t get_num_cores_channels_from_parallel_config(const ParallelConfig& pconfig) { + TT_ASSERT(!pconfig.grid.ranges().empty()); + TT_ASSERT( + pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED || + pconfig.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED || + pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED); + auto grid_size = pconfig.grid.bounding_box().grid_size(); + uint32_t num_cores_channels = 0; + if (pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { + num_cores_channels = 1; + } else if(pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { + num_cores_channels = pconfig.grid.num_cores(); + } else if (pconfig.shard_orientation == ShardOrientation::COL_MAJOR) { + num_cores_channels = grid_size.y; + } else { + TT_ASSERT(pconfig.shard_orientation == ShardOrientation::ROW_MAJOR); + num_cores_channels = grid_size.x; + } + TT_ASSERT(num_cores_channels > 0); + return num_cores_channels; +} + +MemoryConfig create_sharded_memory_config_from_parallel_config( + const ttnn::Shape& tensor_shape, const ParallelConfig& parallel_config, uint32_t tile_size) { + + log_debug(tt::LogOp, "create_sharded_memory_config_from_parallel_config: tensor_shape: {}, parallel_config: {}, tile_size: {}", tensor_shape, parallel_config, tile_size); + // tensor_shape is [N, H, W, C] + TT_ASSERT(tensor_shape[0] == 1 && tensor_shape[1] == 1); // todo: add support for generic non-2d shapes + // uint32_t channels = tensor_shape[3]; + uint32_t channels = tensor_shape.with_tile_padding()[3]; + uint32_t num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); + uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config); + auto shard_scheme = parallel_config.shard_scheme; + auto shard_orientation = parallel_config.shard_orientation; + + uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2]; + uint32_t nhw_padded = nhw_shape; + if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { + nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); + } + uint32_t nhw_shard = nhw_padded / num_cores_nhw; + TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels); + uint32_t channel_shard = channels / num_cores_channels; + auto shard_spec = ShardSpec{parallel_config.grid, 
{nhw_shard, channel_shard}, shard_orientation}; + return MemoryConfig{shard_scheme, BufferType::L1, shard_spec}; +} + + +OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( + const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c) { + TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); + const auto& shard_spec = conv_output_mem_config.shard_spec.value(); + const auto& shard_shape = shard_spec.shape; + TT_ASSERT(shard_shape[1] % 32 == 0); + uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32); + return { + .grid_size = shard_spec.grid.bounding_box().grid_size(), + .num_cores_nhw = num_cores_nhw, + .num_cores_c = num_cores_c, + .per_core_out_matrix_height_ntiles = per_core_out_matrix_height_ntiles, + .per_core_out_matrix_width_ntiles = shard_shape[1] / 32, + .per_core_out_matrix_height = shard_shape[0], + .per_core_out_matrix_width = shard_shape[1], + }; +} + +std::pair determine_largest_subblock_size( + uint32_t block_height, uint32_t block_width, bool fp32_accum, bool split_reader_enabled) { + constexpr std::array, 20> subblocks = {{ + {2, 4}, {4, 2}, {1, 8}, {8, 1}, {1, 7}, {7, 1}, {2, 3}, {3, 2}, {1, 6}, {6, 1}, + {1, 5}, {5, 1}, {2, 2}, {1, 4}, {4, 1}, {1, 3}, {3, 1}, {1, 2}, {2, 1}, {1, 1}, + }}; + + uint32_t subblock_h = 0; + uint32_t subblock_w = 0; + for (auto [subblock_height, subblock_width] : subblocks) { + if (fp32_accum && (subblock_height * subblock_width > 4)) { + continue; + } + + if (split_reader_enabled && (block_height / subblock_height) < 2) { + continue; + } + + if ((block_height % subblock_height == 0) && (block_width % subblock_width == 0)) { + if (subblock_width != block_width && subblock_height != 1) { + continue; + } + subblock_h = subblock_height; + subblock_w = subblock_width; + break; + } + } + TT_ASSERT(subblock_h > 0 && subblock_w > 0); + return {subblock_h, subblock_w}; +} + +OptimizedConvBlockConfig determine_per_core_conv_block_config( + const ParallelConfig& parallel_config, + const OptimizedConvParallelizationConfig& conv_op_parallel_config, + uint32_t padded_in_channels, + uint32_t act_block_h_override, + uint32_t act_block_w_div, + uint32_t window_h, + uint32_t window_w, + bool fp32_accum, + bool split_reader_enabled) { + + if (act_block_h_override > 0) { + TT_ASSERT( + act_block_h_override % 32 == 0, + "Config Error: act_block_h_override must be a multiple of 32 (tile height)."); + } + auto grid_size = parallel_config.grid.bounding_box().grid_size(); + uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; + if (parallel_config.shard_scheme != TensorMemoryLayout::WIDTH_SHARDED && act_block_h_override > 0 ) { + log_debug(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used"); + act_block_h_ntiles = act_block_h_override / constants::TILE_HEIGHT; + } + uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED + ? round_up(padded_in_channels * window_w, 32) + : padded_in_channels; + if(parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { + act_block_w = (padded_in_channels * window_h * window_w)/(parallel_config.grid.num_cores() * act_block_w_div); + } + TT_ASSERT(act_block_w % 32 == 0); + uint32_t act_block_w_ntiles = act_block_w / 32; + uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1 + : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? 
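determine_conv_op_parallel_config_from_conv_output_mem_config just re-expresses the output shard shape in tiles; conv_out_memory_config below is an assumed name for the sharded output MemoryConfig built a few lines earlier. Continuing the hypothetical 98x2-tile shard from the example above:

// Illustrative only: an output shard of {3136, 64} elements per core.
auto opt_pconfig = determine_conv_op_parallel_config_from_conv_output_mem_config(
    conv_out_memory_config, /*num_cores_nhw=*/8, /*num_cores_c=*/4);
// per_core_out_matrix_height_ntiles = div_up(3136, 32) = 98
// per_core_out_matrix_width_ntiles  = 64 / 32          = 2
// grid_size comes from the shard grid's bounding box; num_cores_nhw / num_cores_c are forwarded
// from the parallel config that produced the memory config.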
grid_size.y + : grid_size.x; + uint32_t out_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; + uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntiles; + //act_block_h_ntiles / block_config.out_subblock_h_ntiles) >= 2 + auto [out_subblock_h_ntiles, out_subblock_w_ntiles] = + determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled); + return { + .act_block_h_ntiles = act_block_h_ntiles, + .act_block_w_ntiles = act_block_w_ntiles, + .out_subblock_h_ntiles = out_subblock_h_ntiles, + .out_subblock_w_ntiles = out_subblock_w_ntiles}; +} + +bool use_matmul_for_1x1_conv( + const std::array& kernel_size, + const std::array& stride, + const std::array& padding, + const std::array& dilation, + uint32_t groups) { + return kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 && padding[0] == 0 && + padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1; +} + +// Implements a heuristic for selecting shard layout based on how many tenix cores are available +// for each shard. +static TensorMemoryLayout select_shard_spec( + bool is_mm_conv, + uint32_t batch_size, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t weights_width, + uint32_t input_width, + ShardOrientation shard_orientation, + const std::array& kernel_size, + const std::array& stride, + const CoreCoord& compute_grid_size) { + auto get_core_count_for_sharding = [&](TensorMemoryLayout shard_layout) { + return determine_parallel_config( + shard_layout, + batch_size, + in_channels, + output_height, + output_width, + out_channels, + compute_grid_size, + shard_orientation) + .grid.num_cores(); + }; + + // Block sharding supports very few kernel dims. + const bool is_block_sharding_valid = + (kernel_size[0] == 3 && kernel_size[1] == 3 && (stride[0] == 1 || stride[0] == 2)) || + (kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == 2); + + // 1d convs support only height sharding + const bool is_conv1d = weights_width == 1 && input_width == 1; + + const uint32_t cc_height = get_core_count_for_sharding(TensorMemoryLayout::HEIGHT_SHARDED); + // matmul doesn't support width sharding + const uint32_t cc_width = + !is_mm_conv && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::WIDTH_SHARDED) : 0; + const uint32_t cc_block = + (is_block_sharding_valid || is_mm_conv) && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::BLOCK_SHARDED) : 0; + + uint32_t max_cc = cc_block; + TensorMemoryLayout shard_layout = TensorMemoryLayout::BLOCK_SHARDED; + + // Prefer block sharding over height sharding but make sure that we got at least + // some blocking on width dimension as well. + if (cc_height > max_cc || (cc_height == max_cc && cc_height <= compute_grid_size.x)) { + shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; + max_cc = cc_height; + } + + if (cc_width >= max_cc) { + shard_layout = TensorMemoryLayout::WIDTH_SHARDED; + max_cc = cc_width; + } + + if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { + // For large number of input channels prefer width sharding + // even if it has less cores. + // For BH we probably need to adjust this, or even better we make block sharding + // more configurable rearding l1 memory usage for weights. 
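The activation block width chosen above differs per sharding scheme; with hypothetical sizes the three branches give:

// Illustrative only: padded_in_channels = 256, 3x3 window, 8-core width-sharded grid, act_block_w_div = 1.
//   HEIGHT_SHARDED: act_block_w = round_up(256 * 3, 32) = 768 -> 24 tiles (a full filter row)
//   BLOCK_SHARDED:  act_block_w = 256                         ->  8 tiles (this core's channel slice)
//   WIDTH_SHARDED:  act_block_w = 256 * 3 * 3 / (8 * 1)       ->  9 tiles; act_block_w_div can
//                   shrink it further when L1 is tight, as long as it stays a multiple of 32.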
+ if (cc_width >= 40 && in_channels > 1280) { + shard_layout = TensorMemoryLayout::WIDTH_SHARDED; + log_debug(LogOp, "Switching to WIDTH_SHARDED layout due to large in_channels"); + max_cc = cc_width; + } + } + log_debug(LogOp, "Selected shard layout: {}, num cores: {}", shard_layout, max_cc); + + return shard_layout; +} + +template +std::tuple get_conv_padded_input_shape_and_mem_config( + T* device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels) { + ttnn::Tensor input_tensor = input_tensor_; // tensor to return + bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); + bool needs_shard_or_reshard = false; + if (conv_config.override_sharding_config && conv_config.reshard_if_not_optimal) { + TT_ASSERT( + false, + "Incorrect config provided: reshard_if_not_optimal and override_sharding_config cannot both be set."); + } + + TT_FATAL( + (!input_tensor_on_device || input_tensor_.is_sharded()) || conv_config.shard_layout.has_value(), + "Tesor must be sharded or shard_layout must be set."); + + TensorMemoryLayout shard_layout; + if (conv_config.shard_layout.has_value()) { + shard_layout = conv_config.shard_layout.value(); + } + + ParallelConfig input_tensor_parallel_config; + if (!input_tensor_on_device) { + needs_shard_or_reshard = true; + } else { + const auto& input_memory_config = input_tensor_.memory_config(); + if (!input_memory_config.is_sharded()) { + needs_shard_or_reshard = true; + } else { + const auto input_shard_scheme = input_memory_config.memory_layout; + const auto input_shard_orientation = input_memory_config.shard_spec.value().orientation; + const auto input_shard_grid = input_memory_config.shard_spec.value().grid; + ParallelConfig pconfig = { + .grid = input_shard_grid, + .shard_scheme = input_shard_scheme, + .shard_orientation = input_shard_orientation}; + input_tensor_parallel_config = pconfig; + if (input_shard_scheme != TensorMemoryLayout::BLOCK_SHARDED && + input_shard_orientation != ShardOrientation::ROW_MAJOR) { + needs_shard_or_reshard = true; + } + if (input_shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED && + input_shard_scheme != TensorMemoryLayout::BLOCK_SHARDED && + input_shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { + needs_shard_or_reshard = true; + } + if (conv_config.override_sharding_config) { + TT_FATAL(conv_config.core_grid.has_value(), "If override_sharding_config is set, core_grid must be set as well."); + TT_FATAL(conv_config.shard_layout.has_value(), "If override_sharding_config is set, shard_layout must be set as well."); + if (conv_config.core_grid.value() != input_shard_grid) { + needs_shard_or_reshard = true; + } + if(shard_layout!=input_shard_scheme) { + needs_shard_or_reshard = true; + } + bool input_transpose_shards = input_shard_orientation == ShardOrientation::COL_MAJOR; + if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED && conv_config.transpose_shards != input_transpose_shards) { + needs_shard_or_reshard = true; + } + } + } + } + + // shallow conv variriant not supported + // out_channels <= 256 incorrect output from pack_untilize_dst if output > 256 Tracking --> #14236 + // bf8 not supported due to limation of sharding dim multipl of 32 + bool use_non_tile_height = (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) && out_channels <= 256 && conv_config.act_block_h_override == 0 && + (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == 
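select_shard_spec (static, so not callable outside this translation unit) picks whichever layout can occupy the most cores, with block sharding only admitted for the kernel/stride shapes its kernels actually support. The validity predicate, evaluated for a hypothetical kernel:

// Illustrative only: block sharding is considered for 3x3 kernels with stride 1 or 2, for
// 1x1 kernels with stride 2, and for matmul-style convs; everything else competes only between
// height and width sharding.
std::array<uint32_t, 2> kernel{3, 3}, stride{1, 1};
bool block_candidate = (kernel[0] == 3 && kernel[1] == 3 && (stride[0] == 1 || stride[0] == 2)) ||
                       (kernel[0] == 1 && kernel[1] == 1 && stride[0] == 2);  // true here
// A block/height tie goes to height sharding when it fits within one grid row, and very wide
// layers (in_channels > 1280 with >= 40 width-sharded cores) are steered to WIDTH_SHARDED.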
DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR && conv_config.input_channels_alignment != 16; // shallow conv variant + + ParallelConfig parallel_config = input_tensor_parallel_config; + if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { + auto block_shard_orientation = + conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; + ParallelConfig optimal_parallel_config = determine_parallel_config( + shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height); + + if (conv_config.override_sharding_config) { + TT_FATAL(conv_config.core_grid.has_value(), "Error"); + // override parallel config + auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED + ? block_shard_orientation + : ShardOrientation::ROW_MAJOR; + parallel_config = { + .grid = conv_config.core_grid.value(), + .shard_scheme = shard_layout, + .shard_orientation = shard_orientation}; + } else { + parallel_config = optimal_parallel_config; + } + if (input_tensor_parallel_config != parallel_config) { + needs_shard_or_reshard = true; + } + } + if (needs_shard_or_reshard) { + uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); + // TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape()); + uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2]; + uint32_t round_up_size = (use_non_tile_height || conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED) ? 1 : tt::constants::TILE_HEIGHT; + uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); + TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); + uint32_t tensor_width = input_tensor.get_shape()[3]; + uint32_t input_tensor_width_snapped_to_channels_alignment = + tt::round_up(tensor_width, conv_config.input_channels_alignment); + TT_ASSERT(input_tensor_width_snapped_to_channels_alignment >= tensor_width); + + auto input_padded_shape = ttnn::Shape(std::array{ + 1, + 1, + input_tensor_height_snapped_to_tile, + input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and + // tt::tt_metal::LegacyShape issue to clean up next line + auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( + ttnn::Shape(std::array{ + input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), + parallel_config, round_up_size); + return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height}; + } else { + return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard, use_non_tile_height}; + } +} + +template +std::tuple shard_or_reshard_tensor_if_required( + T* device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels, + bool is_mm_conv) { + ttnn::Tensor input_tensor = input_tensor_; // tensor to return + bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); + + auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height] = + get_conv_padded_input_shape_and_mem_config( + device, + input_tensor_, + conv_config, + batch_size, + height, + width, + in_channels, + out_channels); + 
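The padded input shape produced by get_conv_padded_input_shape_and_mem_config snaps the flattened N*H*W dimension up to a multiple of num_cores_nhw times the tile height (or times 1 in the non-tile-height and width-sharded cases) and snaps the channel dimension up to input_channels_alignment. A minimal Python sketch of that rounding, using hypothetical values for the core count and a 100x100 RGB input:

```python
TILE_HEIGHT = 32

def round_up(value: int, multiple: int) -> int:
    # Same idea as tt::round_up: smallest multiple of `multiple` that is >= value.
    return -(-value // multiple) * multiple

# Hypothetical example: batch 1, 100x100 input, 3 channels, 64 NHW cores,
# default input_channels_alignment of 32, tile-height rounding enabled.
batch, height, width, channels = 1, 100, 100, 3
num_cores_nhw, channels_alignment = 64, 32

nhw = batch * height * width                              # 10000
padded_nhw = round_up(nhw, num_cores_nhw * TILE_HEIGHT)   # 10240
padded_c = round_up(channels, channels_alignment)         # 32

# The conv input is then handled as a [1, 1, padded_nhw, padded_c] sharded tensor.
print(padded_nhw, padded_c)
```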
ParallelConfig parallel_config = { + .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, + .shard_scheme = input_tensor_sharded_memory_config.memory_layout, + .shard_orientation = input_tensor_sharded_memory_config.shard_spec.value().orientation + }; + if (needs_shard_or_reshard) { + if (input_tensor.get_shape()[0] != 1 or input_tensor.get_shape()[1] != 1) { + // reshape to [1, 1, N*H*W, C] + input_tensor = ttnn::reshape( + input_tensor, + ttnn::SimpleShape(std::array{ + 1, + 1, + input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2], + input_tensor.get_shape()[3]})); + } + + uint32_t tensor_height = input_tensor.get_shape()[2]; + uint32_t tensor_width = input_tensor.get_shape()[3]; + + if (!input_tensor_on_device) { + if (input_padded_shape[-2] != tensor_height || input_padded_shape[-1] != tensor_width) { + input_tensor = ttnn::pad( + input_tensor, + tt::tt_metal::Array4D({input_tensor.get_shape()[0], + input_tensor.get_shape()[1], + input_padded_shape[-2], + input_padded_shape[-1]}), + tt::tt_metal::Array4D({0, 0, 0, 0}), + 0); + } + } + + if (input_tensor_on_device) { + if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && + parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { + // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout + input_tensor = + ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); + } + auto resharded_input_tensor = ttnn::to_memory_config( + input_tensor, input_tensor_sharded_memory_config, std::nullopt); + if (conv_config.deallocate_activation) { + input_tensor.deallocate(); + resharded_input_tensor = ttnn::operations::core::reallocate(resharded_input_tensor, resharded_input_tensor.memory_config()); + } + input_tensor = resharded_input_tensor; + } else { + if (is_mm_conv && input_tensor.layout() == Layout::ROW_MAJOR && + parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) { + // Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout + input_tensor = ttnn::to_device(input_tensor, device, std::nullopt); + input_tensor = + ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); + input_tensor = ttnn::to_memory_config(input_tensor, input_tensor_sharded_memory_config, std::nullopt); + } else { + input_tensor = ttnn::to_device(input_tensor, device, input_tensor_sharded_memory_config); + } + } + } + return {input_tensor, parallel_config, needs_shard_or_reshard, use_non_tile_height}; +} + +void validate_weight_and_bias_tensors( + const ttnn::Tensor& weight_tensor, std::optional& bias_tensor) { + TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE)); + TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR); + TT_ASSERT(weight_tensor.get_shape().rank() == 4); + // TODO: enable this assert + // TT_ASSERT(weight_tensor.get_shape() == weight_tensor.get_legacy_shape()); + if (bias_tensor.has_value()) { + TT_ASSERT(!ttnn::has_storage_type_of(bias_tensor.value(), ttnn::DEVICE_STORAGE_TYPE)); + TT_ASSERT(bias_tensor.value().get_shape().rank() == 4); + TT_ASSERT(bias_tensor.value().get_layout() == Layout::ROW_MAJOR); + // TODO: enable this assert + // TT_ASSERT(bias_tensor.value().get_shape() == bias_tensor.value().get_legacy_shape()); + } +} + +ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config( + OptimizedConvParallelizationConfig conv_parallelization_config, + OptimizedConvBlockConfig conv_blocking_config, 
+ bool height_sharded, + string activation, + bool transpose_mcast, + uint32_t grid_size_along_c) { + if (height_sharded) { + ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig matmul_config = { + .compute_with_storage_grid_size = conv_parallelization_config.grid_size, + .in0_block_w = conv_blocking_config.act_block_w_ntiles, + .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, + .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, + .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntiles, + .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntiles, + .fuse_batch = true, + .mcast_in0 = false}; + if (activation != "") { + matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation); + } + return matmul_config; + } else { + TT_ASSERT(conv_blocking_config.act_block_w_ntiles % grid_size_along_c == 0); + ttnn::operations::matmul::MatmulMultiCoreReuseMultiCastProgramConfig matmul_config = { + .compute_with_storage_grid_size = conv_parallelization_config.grid_size, + .in0_block_w = conv_blocking_config.act_block_w_ntiles / grid_size_along_c, + .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, + .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, + .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntiles, + .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntiles, + .transpose_mcast = transpose_mcast}; + if (activation != "") { + matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation); + } + return matmul_config; + } +} + +void adjust_conv_op_config_for_auto_shard_if_necessary( + bool is_mm_conv, + uint32_t batch_size, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t weights_width, + uint32_t input_width, + const std::array& kernel_size, + const std::array& stride, + const CoreCoord& compute_grid_size, + Conv2dConfig& conv_config, + std::optional input_memory_config) { + + // If the input tensor is already sharded, or the conv_config has a specified shard layout, we don't need to do anything. + if ((input_memory_config.has_value() && input_memory_config.value().is_sharded()) || conv_config.shard_layout.has_value()) { + return; + } + + ShardOrientation shard_orientation = + conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; + conv_config.shard_layout = select_shard_spec( + is_mm_conv, + batch_size, + in_channels, + out_channels, + output_height, + output_width, + weights_width, + input_width, + shard_orientation, + kernel_size, + stride, + compute_grid_size); + + if (conv_config.act_block_h_override == 0 && conv_config.shard_layout != TensorMemoryLayout::WIDTH_SHARDED) { + if (in_channels <= constants::TILE_WIDTH / 2 && conv_config.input_channels_alignment == constants::TILE_WIDTH && + !is_mm_conv && conv_config.shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + log_debug(LogOp, "Auto shard, enable shallow conv"); + // height sharded, non matmul conv, with input channels <= 16, and default setting for + // input_channels_alignment + conv_config.input_channels_alignment = constants::TILE_WIDTH / 2; + } + + // Set act_block_h_override to min value to + // be conservative with L1 memory usage. 
+ conv_config.act_block_h_override = constants::TILE_HEIGHT; + } +} + +template std::tuple get_conv_padded_input_shape_and_mem_config( + Device* device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels); + +template std::tuple get_conv_padded_input_shape_and_mem_config( + MeshDevice * device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels); + +template std::tuple shard_or_reshard_tensor_if_required( + Device* device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels, + bool is_mm_conv); + +template std::tuple shard_or_reshard_tensor_if_required( + MeshDevice * device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channel, + bool is_mm_conv); + +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp new file mode 100644 index 00000000000..5b9583d23ef --- /dev/null +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -0,0 +1,234 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include +#include + +#include "ttnn/core.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/matmul/matmul.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" +#include "ttnn/types.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include "tt_metal/impl/dispatch/command_queue.hpp" +#include "tt_metal/common/math.hpp" +#include "ttnn/operations/data_movement/pad/pad.hpp" +#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/operations/sliding_window/sliding_window.hpp" +#include "ttnn/operations/sliding_window/halo/halo.hpp" +#include "tt_metal/common/core_coord.hpp" + +namespace ttnn { + +namespace operations::conv { +using namespace conv2d; +using OutputHeight = uint32_t; +using OutputWidth = uint32_t; +using Result = std::tuple>; + +struct Conv2dConfig { + MathFidelity math_fidelity = MathFidelity::HiFi4; + DataType dtype = DataType::BFLOAT16; + DataType weights_dtype = DataType::BFLOAT16; + bool math_approx_mode_enabled = true; + bool fp32_dest_acc_enabled = false; + bool packer_l1_accum_enabled = false; + string activation = ""; + uint32_t input_channels_alignment = 32; + bool deallocate_activation = false; + bool reallocate_halo_output = false; + uint32_t act_block_h_override = 0; // This argument is ignored when shard_layout == WIDTH_SHARDED. + uint32_t act_block_w_div = 1; //Amount by which the maximum possible act_block_width is divided. 
Max act_block_w = (in_channels * window_w * window_h)/total_num_cores; + //Ignored when shard_layout == HEIGHT_SHARDED or BLOCK_SHARDED + bool reshard_if_not_optimal = false; // if true, override_sharding_config should not be set to true + bool override_sharding_config = false; // if true, reshard_if_not_optimal should not be set to true + std::optional shard_layout; + std::optional core_grid = std::nullopt; // used only if override_sharding_config is true + bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false + Layout output_layout = Layout::TILE; + bool enable_act_double_buffer = false; + bool enable_weights_double_buffer = false; // Used only for block sharded convolutions + bool enable_split_reader = false; + bool enable_subblock_padding = false; + static constexpr auto attribute_names = std::make_tuple( + "math_fidelity", + "dtype", + "weights_dtype", + "math_approx_mode_enabled", + "fp32_dest_acc_enabled", + "packer_l1_accum_enabled", + "activation", + "input_channels_alignment", + "deallocate_activation", + "reallocate_halo_output", + "act_block_h_override", + "act_block_w_div", + "reshard_if_not_optimal", + "override_sharding_config", + "shard_layout", + "core_grid", + "transpose_shards", + "output_layout", + "enable_act_double_buffer", + "enable_weights_double_buffer", + "enable_split_reader", + "enable_subblock_padding"); + const auto attribute_values() const { + return std::make_tuple( + std::cref(this->math_fidelity), + std::cref(this->dtype), + std::cref(this->weights_dtype), + std::cref(this->math_approx_mode_enabled), + std::cref(this->fp32_dest_acc_enabled), + std::cref(this->packer_l1_accum_enabled), + std::cref(this->activation), + std::cref(this->input_channels_alignment), + std::cref(this->deallocate_activation), + std::cref(this->reallocate_halo_output), + std::cref(this->act_block_h_override), + std::cref(this->act_block_w_div), + std::cref(this->reshard_if_not_optimal), + std::cref(this->override_sharding_config), + std::cref(this->shard_layout), + std::cref(this->core_grid), + std::cref(this->transpose_shards), + std::cref(this->output_layout), + std::cref(this->enable_act_double_buffer), + std::cref(this->enable_weights_double_buffer), + std::cref(this->enable_split_reader), + std::cref(this->enable_subblock_padding)); + } +}; + +uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor); + +uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor); + +uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor); + +bool use_matmul_for_1x1_conv( + const std::array& kernel_size, + const std::array& stride, + const std::array& padding, + const std::array& dilation, + uint32_t groups); + +sliding_window::ParallelConfig determine_parallel_config( + const TensorMemoryLayout shard_layout, + uint32_t batch_size, + uint32_t input_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t output_channels, + const CoreCoord& compute_grid_size, + ShardOrientation block_shard_orientation, + bool is_out_tiled=true); + +uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig); + +uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig); + +MemoryConfig create_sharded_memory_config_from_parallel_config(const ttnn::Shape& tensor_shape, const sliding_window::ParallelConfig& parallel_config, uint32_t tile_size); + 
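On the Python side these options surface through ttnn.Conv2dConfig. A hedged sketch of overriding a few of them when calling ttnn.conv2d; the device plus host input and weight tensors are assumed to exist already, and the keyword names simply mirror the struct fields above:

```python
import ttnn

# Assumed to already exist: `device`, host `input_tensor` (flattened NHWC) and
# host `weight_tensor` (OIHW).
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat16,
    activation="relu",
    shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,  # skip the auto-shard heuristic
    act_block_h_override=32,       # must stay a multiple of the tile height
    deallocate_activation=True,
)

output_tensor = ttnn.conv2d(
    input_tensor=input_tensor,
    weight_tensor=weight_tensor,
    device=device,
    in_channels=64,
    out_channels=64,
    batch_size=1,
    input_height=56,
    input_width=56,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
    dilation=(1, 1),
    groups=1,
    conv_config=conv_config,
)
# With return_output_size / return_prepared_device_weights left at their new
# False defaults, only the output tensor is returned.
```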
+OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config(const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw); + +std::pair determine_largest_subblock_size(uint32_t block_height, uint32_t block_width, bool fp32_accum); + +ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config( + OptimizedConvParallelizationConfig conv_parallelization_config, + OptimizedConvBlockConfig conv_blocking_config, + bool height_sharded, + string activation, + bool transpose_mcast, + uint32_t grid_size_along_c); + +OptimizedConvBlockConfig determine_per_core_conv_block_config( + const sliding_window::ParallelConfig& parallel_config, + const OptimizedConvParallelizationConfig& conv_op_parallel_config, + uint32_t padded_in_channels, + uint32_t act_block_h_override, + uint32_t act_block_w_div, + uint32_t window_h, + uint32_t window_w, + bool fp32_accum, + bool split_reader_enabled); + +template +std::tuple get_conv_padded_input_shape_and_mem_config( + T * device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels); + +OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( + const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); + +void adjust_conv_op_config_for_auto_shard_if_necessary( + bool is_mm_conv, + uint32_t batch_size, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t weights_width, + uint32_t input_width, + const std::array& kernel_size, + const std::array& stride, + const CoreCoord& compute_grid_size, + Conv2dConfig& conv_config, + std::optional input_memory_config); + +template +std::tuple shard_or_reshard_tensor_if_required( + T* device, + const ttnn::Tensor& input_tensor_, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels, + bool is_mm_conv); + +// Converts convolution weights to tilized 2d matrix layout. 
+// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout( + Tensor conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype = std::nullopt); + +// Converts convolution weights to tilized 2d matrix layout with special block height padding +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( + Tensor conv_weight_tensor, + uint32_t in1_block_h, + uint32_t in1_block_w, + std::optional output_dtype = std::nullopt); + +// Converts convolution weights to grouped layout with padded zeros +Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype); + +template +OptimizedConvBlockConfig get_opt_block_config( + bool mm_conv, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t batch_size, + uint32_t input_width, + std::array kernel_size, + std::array stride, + T *device, + Conv2dConfig& conv_config); + +} // namespace operations::conv +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp new file mode 100644 index 00000000000..c89b3e8fa8a --- /dev/null +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -0,0 +1,381 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "prepare_conv2d_weights.hpp" +#include "conv2d_utils.hpp" +#include +#include + +using namespace tt; +namespace ttnn { +namespace operations::conv { +using sliding_window::SlidingWindowConfig; +using sliding_window::ParallelConfig; + +namespace conv2d { + +void validate_weight_tensor(const ttnn::Tensor& weight_tensor) { + TT_ASSERT(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE)); + TT_ASSERT(weight_tensor.get_layout() == Layout::ROW_MAJOR); + TT_ASSERT(weight_tensor.get_shape().rank() == 4); +} + +void validate_bias_tensor(const ttnn::Tensor& bias_tensor) { + TT_ASSERT(!ttnn::has_storage_type_of(bias_tensor, ttnn::DEVICE_STORAGE_TYPE)); + TT_ASSERT(bias_tensor.get_shape().rank() == 4); + TT_ASSERT(bias_tensor.get_layout() == Layout::ROW_MAJOR); +} + +void validate_weights_format(std::string weights_format) { + TT_FATAL(weights_format.size() == 4, "weights_format must have exactly 4 characters"); + TT_ASSERT(weights_format.find("O") != string::npos, "weights_format must contain \"O\""); + TT_ASSERT(weights_format.find("I") != string::npos, "weights_format must contain \"I\""); + TT_ASSERT(weights_format.find("H") != string::npos, "weights_format must contain \"H\""); + TT_ASSERT(weights_format.find("W") != string::npos, "weights_format must contain \"W\""); + TT_ASSERT(weights_format == "OIHW", "Conv2d weights format must be \"OIHW\""); +} + +template +OptimizedConvBlockConfig get_opt_block_config( + bool mm_conv, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t batch_size, + uint32_t input_width, + std::array kernel_size, + std::array stride, + T *device, + Conv2dConfig& conv_config, + const MemoryConfig& input_memory_config) { + + + adjust_conv_op_config_for_auto_shard_if_necessary( + mm_conv, + batch_size, + in_channels, + out_channels, + output_height, + output_width, + kernel_size[1], + input_width, + kernel_size, + stride, + device->compute_with_storage_grid_size(), + conv_config, + input_memory_config); + + 
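The output spatial sizes used by prepare_conv_weights and prepare_conv_bias below follow the standard convolution output formula, which is the same expression exposed as get_conv_output_dim in ttnn.operations.conv2d. A quick worked check of the two equivalent forms:

```python
def conv_output_dim(size, kernel, stride=1, pad=0, dilation=1):
    # Same formula as ttnn.operations.conv2d.get_conv_output_dim.
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# 224x224 input, 7x7 kernel, stride 2, padding 3 (ResNet-50 stem): 112x112 output.
assert conv_output_dim(224, 7, stride=2, pad=3) == 112

# The C++ form ((size - kernel - (kernel - 1) * (dilation - 1) + 2 * pad) / stride) + 1
# is identical, since kernel + (kernel - 1) * (dilation - 1) == dilation * (kernel - 1) + 1.
assert ((224 - 7 - (7 - 1) * (1 - 1) + 2 * 3) // 2) + 1 == 112
```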
ShardOrientation shard_orientation = + conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; + + bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && + conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR; + use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; + + ParallelConfig parallel_config = determine_parallel_config( + conv_config.shard_layout.value(), + batch_size, + in_channels, + output_height, + output_width, + out_channels, + device->compute_with_storage_grid_size(), + shard_orientation, + !use_non_tile_height); + + + uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; + auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( + ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), + parallel_config, round_up_size); + auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( + conv_out_memory_config, get_num_cores_nhw_from_parallel_config(parallel_config), + get_num_cores_channels_from_parallel_config(parallel_config)); + + return determine_per_core_conv_block_config( + parallel_config, + opt_conv_op_parallel_config, + tt::round_up(in_channels, conv_config.input_channels_alignment), + conv_config.act_block_h_override, + conv_config.act_block_w_div, + kernel_size[0], + kernel_size[1], + conv_config.fp32_dest_acc_enabled, + conv_config.enable_split_reader); +} + +template +ttnn::Tensor prepare_conv_weights( + const ttnn::Tensor& weight_tensor, + const ttnn::MemoryConfig &input_memory_config, + std::string weights_format, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + T *device, + std::optional conv_config_) { + + TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation."); + + const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); + const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; + const uint32_t output_width = + ((input_width - kernel_size[1] - ((kernel_size[1] - 1) * (dilation[1] - 1)) + 2 * padding[1]) / stride[1]) + 1; + + Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); + auto opt_conv_op_block_config = get_opt_block_config( + mm_conv, + in_channels, + out_channels, + output_height, + output_width, + batch_size, + input_width, + kernel_size, + stride, + device, + conv_config, + input_memory_config + ); + + uint32_t weight_block_h_ntiles = opt_conv_op_block_config.act_block_w_ntiles; + uint32_t weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles; + uint32_t act_block_h_ntiles = opt_conv_op_block_config.act_block_h_ntiles; + + validate_weight_tensor(weight_tensor); + ttnn::Tensor weight_tensor_ = weight_tensor; // tensor to return + + // Permute to OIHW layout as that's what the preparation expects + validate_weights_format(weights_format); + + auto original_weights_shape = weight_tensor.get_shape(); + uint32_t original_weights_out_channels = original_weights_shape[0]; + uint32_t 
original_weights_in_channels = original_weights_shape[1]; + uint32_t original_weights_window_h = original_weights_shape[2]; + uint32_t original_weights_window_w = original_weights_shape[3]; + + bool is_conv1d = original_weights_window_w == 1 && input_width == 1; + bool is_depthwise_conv = groups == original_weights_out_channels && original_weights_in_channels == 1; + + weight_tensor_ = weight_tensor; + + // Convert weight tensor to 0 padded shape if groups > 1 + if (!is_conv1d and groups > 1) { + weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, conv_config.weights_dtype); + } + else if (is_conv1d and groups > 1) { + if (is_depthwise_conv) { + weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, conv_config.weights_dtype); + weight_block_h_ntiles = act_block_h_ntiles; + } + else{ + weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, conv_config.weights_dtype); + } + } + + auto weights_shape = weight_tensor_.get_shape(); + out_channels = weights_shape[0]; + in_channels = weights_shape[1]; + uint32_t window_h = weights_shape[2]; + uint32_t window_w = weights_shape[3]; + uint32_t out_channel_padding = tt::round_up(out_channels, 32) - out_channels; + tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( + {tt::round_up(out_channels, 32), tt::round_up(in_channels, conv_config.input_channels_alignment), window_h, window_w})); + if (conv_config.weights_dtype == DataType::BFLOAT8_B) { + TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32); + } else { + // TODO: fix the need to check this. We should be able to accept any datatype and convert + TT_ASSERT(weight_tensor_.get_dtype() == conv_config.weights_dtype); + } + weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); + + // for conv op, pad the weights to block shape + if (conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED) { + weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( + weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, conv_config.weights_dtype); + } else { + weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout( + weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, conv_config.weights_dtype); + } + + uint32_t weight_matrix_height = in_channels * window_h * window_w; + int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height; + TT_FATAL(weight_matrix_height_padding >= 0," Matrix Height Padding can't be negative"); + + // convert_conv_weight_tensor adds the padding to the base shape. + // Reshape the weights to remove padding from the base shape. 
+ weight_tensor_.set_shape( + ttnn::Shape(std::array{1, 1, weight_matrix_height, out_channels}, + std::array, 4>{ + std::array{0, 0}, + std::array{0, 0}, + std::array{0, weight_matrix_height_padding}, + std::array{0, out_channel_padding} + })); + return weight_tensor_; +} + +template +ttnn::Tensor prepare_conv_bias( + const ttnn::Tensor& bias_tensor, + const ttnn::MemoryConfig& input_memory_config, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + T *device, + std::optional conv_config_) { + + TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(bias_tensor), "Error: bias tensor must be on host for preparation."); + + const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); + const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; + const uint32_t output_width = + ((input_width - kernel_size[1] - ((kernel_size[1] - 1) * (dilation[1] - 1)) + 2 * padding[1]) / stride[1]) + 1; + + Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); + auto opt_conv_op_block_config = get_opt_block_config( + mm_conv, + in_channels, + out_channels, + output_height, + output_width, + batch_size, + input_width, + kernel_size, + stride, + device, + conv_config, + input_memory_config + ); + + uint32_t weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles; + validate_bias_tensor(bias_tensor); + + ttnn::Tensor bias_tensor_; + bias_tensor_ = bias_tensor; + auto bias_shape = bias_tensor_.get_shape(); + TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1); + tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( + std::array({1, 1, 32, tt::round_up(out_channels, weight_block_w_ntiles * 32)})); + bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); + bias_tensor_ = ttnn::to_layout( + bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); + if (bias_tensor_.get_dtype() != conv_config.weights_dtype) { + bias_tensor_ = ttnn::to_dtype(bias_tensor_, conv_config.weights_dtype); + } + return bias_tensor_; +} + +template OptimizedConvBlockConfig get_opt_block_config( + bool mm_conv, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t batch_size, + uint32_t input_width, + std::array kernel_size, + std::array stride, + Device *device, + Conv2dConfig& conv_config, + const ttnn::MemoryConfig& input_memory_config); + +template OptimizedConvBlockConfig get_opt_block_config( + bool mm_conv, + uint32_t in_channels, + uint32_t out_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t batch_size, + uint32_t input_width, + std::array kernel_size, + std::array stride, + MeshDevice *device, + Conv2dConfig& conv_config, + const ttnn::MemoryConfig& input_memory_config); + +template ttnn::Tensor prepare_conv_weights( + const ttnn::Tensor& weight_tensor, + const ttnn::MemoryConfig& input_memory_config, + std::string weights_format, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + Device *device, 
+ std::optional conv_config_); + +template ttnn::Tensor prepare_conv_weights( + const ttnn::Tensor& weight_tensor, + const ttnn::MemoryConfig& input_memory_config, + std::string weights_format, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + MeshDevice *device, + std::optional conv_config_); + +template ttnn::Tensor prepare_conv_bias( + const ttnn::Tensor& bias_tensor, + const ttnn::MemoryConfig& input_memory_config, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + Device *device, + std::optional conv_config_); + +template ttnn::Tensor prepare_conv_bias( + const ttnn::Tensor& bias_tensor, + const ttnn::MemoryConfig& input_memory_config, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + MeshDevice *device, + std::optional conv_config_); + +} // namespace conv2d +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp new file mode 100644 index 00000000000..7df17f1b41e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -0,0 +1,65 @@ + +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "conv2d_utils.hpp" +#include "ttnn/core.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/matmul/matmul.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" +#include "ttnn/types.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include "tt_metal/impl/dispatch/command_queue.hpp" +#include "tt_metal/common/math.hpp" +#include "ttnn/operations/data_movement/pad/pad.hpp" +#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/operations/sliding_window/sliding_window.hpp" +#include "ttnn/operations/sliding_window/halo/halo.hpp" + +namespace ttnn { + +namespace operations::conv { +namespace conv2d { +template +ttnn::Tensor prepare_conv_weights( + const ttnn::Tensor& weight_tensor, + const ttnn::MemoryConfig& input_memory_config, + std::string weights_format, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + T *device, + std::optional conv_config_); + +template +ttnn::Tensor prepare_conv_bias( + const ttnn::Tensor& bias_tensor, + const ttnn::MemoryConfig& input_memory_config, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + T *device, + std::optional conv_config_); + +} // namespace conv2d +} // namespace operations::conv +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp 
b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp index d0943e137fa..e272e262217 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp @@ -3,6 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #include "conv_transpose2d.hpp" +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" +#include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp" +#include "../conv2d/conv2d_utils.hpp" #include #include #include "common/bfloat16.hpp" @@ -102,8 +105,8 @@ Result conv_transpose2d( std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config_) { - conv2d::Conv2dConfig conv_config = conv_config_.value_or(conv2d::Conv2dConfig()); + std::optional conv_config_) { + Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); //Inverse of sliding_window.get_output_shape() SlidingWindowConfig sliding_window_config = SlidingWindowConfig{ @@ -170,10 +173,10 @@ Result conv_transpose2d( TT_THROW("Invalid Device Arch, Got {}",device->arch()); } - const bool mm_conv = conv2d::use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); + const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); //Call Halo Transpose - auto [input_tensor_post_tm, parallel_config, tensor_manipulated, use_non_tile_height] = conv2d::shard_or_reshard_tensor_if_required( + auto [input_tensor_post_tm, parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( device, input_tensor, conv_config, @@ -187,7 +190,7 @@ Result conv_transpose2d( uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; - sliding_window_config.num_cores_nhw = conv2d::get_num_cores_nhw_from_parallel_config(parallel_config); + sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid; sliding_window_config.snap_to_tile = !use_non_tile_height; @@ -211,16 +214,16 @@ Result conv_transpose2d( input_tensor_post_tm.memory_config()); //Call Conv2d u_op with Stride = 1, Padding = 0. 
- auto conv_out_memory_config = conv2d::create_sharded_memory_config_from_parallel_config( + auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), parallel_config, round_up_size); - auto opt_conv_op_parallel_config = conv2d::determine_conv_op_parallel_config_from_conv_output_mem_config( + auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( conv_out_memory_config, - conv2d::get_num_cores_nhw_from_parallel_config(parallel_config), - conv2d::get_num_cores_channels_from_parallel_config(parallel_config) + get_num_cores_nhw_from_parallel_config(parallel_config), + get_num_cores_channels_from_parallel_config(parallel_config) ); - auto opt_conv_op_block_config = conv2d::determine_per_core_conv_block_config( + auto opt_conv_op_block_config = determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, tt::round_up(in_channels, conv_config.input_channels_alignment), @@ -237,19 +240,59 @@ Result conv_transpose2d( std::optional bias_tensor_on_device = bias_tensor; if (!weight_is_on_device) { + // // prepare weights in desired layout and move to device + // tie(weight_tensor_on_device, bias_tensor_on_device) = conv2d::prepare_conv_weights_biases_and_move_to_device( + // transform_weights_for_conv_transpose2d(weight_tensor), + // bias_tensor, + // conv_config.input_channels_alignment, + // conv_config.weights_dtype, + // opt_conv_op_block_config.act_block_w_ntiles, + // opt_conv_op_block_config.out_subblock_w_ntiles, + // parallel_config, + // device, + // groups, + // opt_conv_op_block_config.act_block_h_ntiles, + // input_width); + // prepare weights in desired layout and move to device - tie(weight_tensor_on_device, bias_tensor_on_device) = conv2d::prepare_conv_weights_biases_and_move_to_device( + weight_tensor_on_device = prepare_conv_weights( transform_weights_for_conv_transpose2d(weight_tensor), - bias_tensor, - conv_config.input_channels_alignment, - conv_config.weights_dtype, - opt_conv_op_block_config.act_block_w_ntiles, - opt_conv_op_block_config.out_subblock_w_ntiles, - parallel_config, + input_tensor_post_tm.memory_config(), + "OIHW", + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, device, + conv_config + ); + weight_tensor_on_device = ttnn::operations::core::to_device(weight_tensor_on_device, device, std::nullopt); + } + + if (bias_tensor.has_value() && !ttnn::is_tensor_on_device_or_multidevice(bias_tensor.value())) { + bias_tensor_on_device = prepare_conv_bias( + bias_tensor.value(), + input_tensor_post_tm.memory_config(), + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, groups, - opt_conv_op_block_config.act_block_h_ntiles, - input_width); + device, + conv_config + ); + bias_tensor_on_device = ttnn::operations::core::to_device(bias_tensor_on_device.value(), device, std::nullopt); } // call conv micro op auto conv_output = optimized_conv_new( @@ -292,7 +335,7 @@ Result ConvTranpose2dOperation::invoke( std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config_){ + std::optional conv_config_){ return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, 
conv_config_); } @@ -313,7 +356,7 @@ Result ConvTranpose2dOperation::invoke( std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config_){ + std::optional conv_config_){ return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config_); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp index 0b38317b601..8e7a1f02aec 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp @@ -4,7 +4,7 @@ #pragma once #include -#include "ttnn/operations/conv/conv2d/conv2d.hpp" +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" namespace ttnn { @@ -14,6 +14,7 @@ namespace conv_transpose2d { using OutputHeight = uint32_t; using OutputWidth = uint32_t; using Result = std::tuple>; + struct ConvTranpose2dOperation{ static Result invoke( uint8_t queue_id, @@ -32,7 +33,7 @@ struct ConvTranpose2dOperation{ std::array dilation, uint32_t groups, std::optional bias_tensor = std::nullopt, - std::optional conv_config_ = std::nullopt); + std::optional conv_config_ = std::nullopt); static Result invoke( uint8_t queue_id, @@ -51,7 +52,7 @@ struct ConvTranpose2dOperation{ std::array dilation, uint32_t groups, std::optional bias_tensor = std::nullopt, - std::optional conv_config_ = std::nullopt); + std::optional conv_config_ = std::nullopt); }; } // namespace conv_transpose2d diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp index 8128ea7423b..08654aedea9 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp @@ -103,7 +103,7 @@ void py_bind_conv_transpose2d(py::module& module) { std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config, + std::optional conv_config, const uint8_t& queue_id) -> Result { return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config); }, @@ -142,7 +142,7 @@ void py_bind_conv_transpose2d(py::module& module) { std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config, + std::optional conv_config, const uint8_t& queue_id) -> Result { return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config); }, diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp index bf2cf007f15..d05333aece7 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp @@ -5,7 +5,7 @@ #include "max_pool2d.hpp" #include "impl/buffers/buffer_constants.hpp" -#include "ttnn/operations/conv/conv2d/conv2d.hpp" +#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" #include "ttnn/operations/sliding_window/sliding_window.hpp" #include "tt_metal/common/math.hpp" #include "ttnn/common/constants.hpp" @@ -54,7 +54,7 @@ Tensor 
MaxPool2DOp::invoke(uint8_t queue_id, "Only height, width, or block sharding strategies are supported."); shard_layout = applied_shard_scheme.value(); } - parallel_config = conv::conv2d::determine_parallel_config( + parallel_config = conv::determine_parallel_config( shard_layout, batch_size, channels, @@ -64,9 +64,9 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, input_tensor.device()->compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR, false); - num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config); - num_cores_c = conv::conv2d::get_num_cores_channels_from_parallel_config(parallel_config); - auto sharded_mem_config = conv::conv2d::create_sharded_memory_config_from_parallel_config(input_tensor_sharded.shape(), parallel_config, is_in_tiled ? tt::constants::TILE_HEIGHT : 1); + num_cores_nhw = conv::get_num_cores_nhw_from_parallel_config(parallel_config); + num_cores_c = conv::get_num_cores_channels_from_parallel_config(parallel_config); + auto sharded_mem_config = conv::create_sharded_memory_config_from_parallel_config(input_tensor_sharded.shape(), parallel_config, is_in_tiled ? tt::constants::TILE_HEIGHT : 1); input_tensor_sharded = ttnn::to_memory_config(input_tensor_sharded, sharded_mem_config, std::nullopt); // this converts interleaved to sharded out_memory_config = input_tensor_sharded.memory_config(); } else { @@ -79,8 +79,8 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, parallel_config.grid = shard_grid; parallel_config.shard_scheme = shard_scheme; parallel_config.shard_orientation = shard_orientation; - num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config); - num_cores_c = conv::conv2d::get_num_cores_channels_from_parallel_config(parallel_config); + num_cores_nhw = conv::get_num_cores_nhw_from_parallel_config(parallel_config); + num_cores_c = conv::get_num_cores_channels_from_parallel_config(parallel_config); } // update the shard spec to match the output shape diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index fba363c3971..097b7ea7fb0 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -295,7 +295,14 @@ def prelu(*args, **kwargs): # Alias for leaky_relu. TODO(#8544): implement PReL Topology, ) -from ttnn.operations.conv2d import Conv2dConfig, get_conv_padded_input_shape_and_mem_config, get_conv_output_dim + +from ttnn.operations.conv2d import ( + Conv2dConfig, + get_conv_output_dim, + get_conv_padded_input_shape_and_mem_config, + prepare_conv_weights, + prepare_conv_bias, +) from ttnn.operations.pool import avg_pool2d from ttnn.operations.conv1d import Conv1d, Conv1dConfig diff --git a/ttnn/ttnn/operations/conv1d.py b/ttnn/ttnn/operations/conv1d.py index e979a12b21d..c009dea1bc4 100644 --- a/ttnn/ttnn/operations/conv1d.py +++ b/ttnn/ttnn/operations/conv1d.py @@ -30,6 +30,8 @@ def Conv1d( conv_config: Conv1dConfig = None, # config overrides by user conv_op_cache={}, # basic conv object caching in python needed for intermediate refactoring. Not needed after full op refactoring in C++. 
debug=False, + return_output_length=False, + return_prepared_device_weights=False, ) -> Tuple[ttnn.Tensor, int, int, ttnn.Tensor, ttnn.Tensor]: # Reshape the input and weight tensors to 4D for conv2d operation # Should be no-op as input_tensor is in RM layout @@ -62,12 +64,14 @@ def Conv1d( conv_config=conv_config, ) - return ( - output_tensor_new, - output_length_new, - weight_tensor_on_dev_new, - bias_tensor_on_dev_new, - ) + if return_output_length and return_prepared_device_weights: + return output_tensor_new, output_length_new, weight_tensor_on_dev_new, bias_tensor_on_dev_new + elif return_prepared_device_weights: + return output_tensor_new, weight_tensor_on_dev_new, bias_tensor_on_dev_new + elif return_output_length: + return output_tensor_new, output_length_new + else: + return output_tensor_new __all__ = [] diff --git a/ttnn/ttnn/operations/conv2d.py b/ttnn/ttnn/operations/conv2d.py index ca1f329dd69..edb3ccd575c 100644 --- a/ttnn/ttnn/operations/conv2d.py +++ b/ttnn/ttnn/operations/conv2d.py @@ -33,6 +33,78 @@ def get_conv_output_dim(input, window, stride=1, pad=0, dilation=1): return (input + (2 * pad) - dilation * (window - 1) - 1) // stride + 1 +def prepare_conv_weights( + *, + weight_tensor, + input_memory_config, + weights_format, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + device, + conv_config=None, +): + return ttnn._ttnn.operations.conv2d.prepare_conv_weights( + weight_tensor=weight_tensor, + input_memory_config=input_memory_config, + weights_format=weights_format, + in_channels=in_channels, + out_channels=out_channels, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + kernel_size=list(kernel_size), + stride=list(stride), + padding=list(padding), + dilation=list(dilation), + groups=groups, + device=device, + conv_config=conv_config, + ) + + +def prepare_conv_bias( + *, + bias_tensor, + input_memory_config, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + device, + conv_config=None, +): + return ttnn._ttnn.operations.conv2d.prepare_conv_bias( + bias_tensor=bias_tensor, + input_memory_config=input_memory_config, + in_channels=in_channels, + out_channels=out_channels, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + kernel_size=list(kernel_size), + stride=list(stride), + padding=list(padding), + dilation=list(dilation), + groups=groups, + device=device, + conv_config=conv_config, + ) + + def convert_conv_weight_tensor_to_tiled_layout(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype=None): """ Converts convolution weights to 2d matrix tiled layout on host @@ -104,8 +176,16 @@ def conv2d( memory_config: ttnn.MemoryConfig = None, # memory config overrides by user conv_op_cache={}, # basic conv object caching in python needed for intermediate refactoring. Not needed after full op refactoring in C++. 
debug=False, # ignored + return_output_size=False, + return_prepared_device_weights=False, ) -> Tuple[ttnn.Tensor, int, int, ttnn.Tensor, ttnn.Tensor]: - return ttnn._ttnn.operations.conv.conv2d( + ( + conv_output, + output_height, + output_width, + prepared_device_weight, + prepared_device_bias, + ) = ttnn._ttnn.operations.conv2d.conv2d( input_tensor=input_tensor, weight_tensor=weight_tensor, device=device, @@ -124,5 +204,14 @@ def conv2d( memory_config=memory_config, ) + if return_output_size and return_prepared_device_weights: + return conv_output, output_height, output_width, prepared_device_weight, prepared_device_bias + elif return_prepared_device_weights: + return conv_output, prepared_device_weight, prepared_device_bias + elif return_output_size: + return conv_output, output_height, output_width + else: + return conv_output + __all__ = [] diff --git a/ttnn/tutorials/006.ipynb b/ttnn/tutorials/006.ipynb index 9550d9cc24e..0bda919b3a4 100644 --- a/ttnn/tutorials/006.ipynb +++ b/ttnn/tutorials/006.ipynb @@ -950,6 +950,8 @@ " input_width=self.input_shape[2],\n", " conv_config=conv_config,\n", " groups=self.groups,\n", + " return_output_size=True,\n", + " return_prepared_device_weights=True\n", " )\n", "\n", " return output_tensor\n",
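Taken together, the new entry points let a model prepare conv weights and bias on the host once, move them to the device, and then call conv2d with the new flags left at their defaults so only the output tensor comes back. A hedged end-to-end sketch; the open device, the activation tensor, and the exact shapes are assumptions for illustration:

```python
import torch
import ttnn

# Assumed to exist: an open `device` and a prepared activation `input_tensor`.
torch_weight = torch.randn(64, 32, 3, 3, dtype=torch.bfloat16)
torch_bias = torch.randn(1, 1, 1, 64, dtype=torch.bfloat16)

conv_config = ttnn.Conv2dConfig(weights_dtype=ttnn.bfloat16)
conv_kwargs = dict(
    in_channels=32,
    out_channels=64,
    batch_size=1,
    input_height=56,
    input_width=56,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
    dilation=(1, 1),
    groups=1,
    device=device,
    conv_config=conv_config,
)

# Host-side preparation; a plain DRAM memory config stands in for the real
# input memory config here.
weight_tensor = ttnn.prepare_conv_weights(
    weight_tensor=ttnn.from_torch(torch_weight),
    input_memory_config=ttnn.DRAM_MEMORY_CONFIG,
    weights_format="OIHW",
    **conv_kwargs,
)
bias_tensor = ttnn.prepare_conv_bias(
    bias_tensor=ttnn.from_torch(torch_bias),
    input_memory_config=ttnn.DRAM_MEMORY_CONFIG,
    **conv_kwargs,
)
weight_tensor = ttnn.to_device(weight_tensor, device)
bias_tensor = ttnn.to_device(bias_tensor, device)

# Defaults return only the conv output; pass return_output_size=True and/or
# return_prepared_device_weights=True to get the old tuple back.
output_tensor = ttnn.conv2d(
    input_tensor=input_tensor,
    weight_tensor=weight_tensor,
    bias_tensor=bias_tensor,
    **conv_kwargs,
)
```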