From 10000866505d056357fe39581ee159701ef54ec6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mustafa=20Ebrar=20Akta=C5=9F?=
Date: Tue, 13 Aug 2024 10:45:55 +0300
Subject: [PATCH] feat: grouped conv1d (#1749)

* feat: implement grouped convolution for CPU

* feat: implement grouped Conv1D for DNNL

* feat: implement grouped Conv1D for CUDA
---
 include/ctranslate2/layers/common.h |  3 +-
 include/ctranslate2/ops/conv1d.h    |  3 +-
 src/layers/common.cc                |  5 +-
 src/ops/conv1d.cc                   |  3 +-
 src/ops/conv1d_cpu.cc               | 61 +++++++++++++---------
 src/ops/conv1d_gpu.cu               |  5 +-
 tests/ops_test.cc                   | 78 +++++++++++++++++++++++++++++
 7 files changed, 128 insertions(+), 30 deletions(-)

diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index 3985b3feb..137b926d3 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -174,7 +174,8 @@ namespace ctranslate2 {
              const std::string& scope,
              dim_t stride = 1,
              dim_t padding = 0,
-             dim_t dilation = 1);
+             dim_t dilation = 1,
+             dim_t groups = 1);
       DataType output_type() const override;
       dim_t output_size() const override;
       dim_t input_size() const;
diff --git a/include/ctranslate2/ops/conv1d.h b/include/ctranslate2/ops/conv1d.h
index fc37021d0..2b18d5632 100644
--- a/include/ctranslate2/ops/conv1d.h
+++ b/include/ctranslate2/ops/conv1d.h
@@ -8,7 +8,7 @@ namespace ctranslate2 {
 
     class Conv1D : public Op {
     public:
-      Conv1D(dim_t stride = 1, dim_t padding = 0, dim_t dilation = 1);
+      Conv1D(dim_t stride = 1, dim_t padding = 0, dim_t dilation = 1, dim_t groups=1);
 
       void operator()(const StorageView& input,
                       const StorageView& weight,
@@ -25,6 +25,7 @@ namespace ctranslate2 {
       dim_t _stride;
       dim_t _padding;
       dim_t _dilation;
+      dim_t _groups;
 
       void operator()(const StorageView& input,
                       const StorageView& weight,
diff --git a/src/layers/common.cc b/src/layers/common.cc
index 86fb66a7d..c6d1cd0b5 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -467,8 +467,9 @@ namespace ctranslate2 {
                    const std::string& scope,
                    dim_t stride,
                    dim_t padding,
-                   dim_t dilation)
-      : _conv_op(stride, padding, dilation)
+                   dim_t dilation,
+                   dim_t groups)
+      : _conv_op(stride, padding, dilation, groups)
       , _weight(model.get_variable(scope + "/weight"))
       , _bias(model.get_variable_if_exists(scope + "/bias"))
       , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) {
diff --git a/src/ops/conv1d.cc b/src/ops/conv1d.cc
index bde97dc8f..02ca6c72d 100644
--- a/src/ops/conv1d.cc
+++ b/src/ops/conv1d.cc
@@ -5,10 +5,11 @@
 namespace ctranslate2 {
   namespace ops {
 
-    Conv1D::Conv1D(dim_t stride, dim_t padding, dim_t dilation)
+    Conv1D::Conv1D(dim_t stride, dim_t padding, dim_t dilation, dim_t groups)
       : _stride(stride)
       , _padding(padding)
       , _dilation(dilation)
+      , _groups(groups)
     {
     }
 
diff --git a/src/ops/conv1d_cpu.cc b/src/ops/conv1d_cpu.cc
index a45388b57..fc9d82f7c 100644
--- a/src/ops/conv1d_cpu.cc
+++ b/src/ops/conv1d_cpu.cc
@@ -19,7 +19,7 @@ namespace ctranslate2 {
 
       dnnl::memory::dims input_dims(input.shape().begin(), input.shape().end());
       dnnl::memory::dims output_dims(output.shape().begin(), output.shape().end());
-      dnnl::memory::dims weight_dims(weight.shape().begin(), weight.shape().end());
+      dnnl::memory::dims weight_dims{_groups, weight.dim(0) / _groups, weight.dim(1), weight.dim(2)};
 
       using tag = dnnl::memory::format_tag;
       using dt = dnnl::memory::data_type;
@@ -32,7 +32,7 @@ namespace ctranslate2 {
       dnnl::memory input_mem({input_dims, dt::f32, tag::ncw}, engine,
                              const_cast<void*>(input.buffer()));
       dnnl::memory output_mem({output_dims, dt::f32, tag::ncw}, engine, output.buffer());
-      dnnl::memory weight_mem({weight_dims, dt::f32, tag::oiw}, engine,
+      dnnl::memory weight_mem({weight_dims, dt::f32, tag::goiw}, engine,
                               const_cast<void*>(weight.buffer()));
 
       dnnl::memory::dims stride{_stride};
@@ -160,9 +160,10 @@ namespace ctranslate2 {
       const dim_t out_channels = weight.dim(0);
       const dim_t kernel_size = weight.dim(2);
       const dim_t output_length = output.dim(2);
+      const dim_t in_channels_per_group = in_channels / _groups;
 
       // Create im2col_output tensor.
-      // im2col_output shape is (batch_size, out_length, in_channels * kernel_size).
+      // im2col_output shape is (batch_size, groups, out_length, in_channels_per_group * kernel_size).
       // This is necessary for quantization:
       //  * we need to run GEMM as (weight x im2col_output) to avoid extra copies
       //  * input (RHS) must be quantized along columns, to dequantize later
@@ -171,17 +172,19 @@ namespace ctranslate2 {
       //  * We can use qinput_scale generated from rows of this im2col_output, as they correspond
       //    to columns of the multiplied shape (because of transpose).
       // we can use same matrix for FLOAT32 computation, too.
-      StorageView im2col_output({batch_size, output_length, in_channels * kernel_size}, 0.0f, weight.device());
+      StorageView im2col_output({batch_size, _groups, output_length, in_channels_per_group * kernel_size}, 0.0f, weight.device());
       im2col_transposed(input, im2col_output, kernel_size);
-      // Create a 2D view of weight to use in GEMM
-      StorageView weight_view(weight.dtype(), weight.device());
-      weight_view.view(const_cast<void*>(weight.buffer()), {weight.dim(0), in_channels * kernel_size});
-
-      const dim_t m = out_channels;
+      // out: bs x (group * out_per_group) x out_len
+      const dim_t m = out_channels / _groups;
       const dim_t n = output_length;
-      const dim_t k = in_channels * kernel_size;
+      const dim_t k = in_channels_per_group * kernel_size;
+      // stridew in bytes to handle quantized weights
+      const dim_t stridew = out_channels / _groups * in_channels_per_group * kernel_size * weight.item_size();
       const dim_t strideb = k * output_length;
-      const dim_t stridec = out_channels * output_length;
+      const dim_t stridec = m * output_length;
+      const dim_t qscale_stride = qscale ? qscale->dim(0) / _groups : 0;
+      // Create byte pointer to the weights, and we will use group weight slice in the loop below
+      auto* w = static_cast<int8_t*>(const_cast<void*>(weight.buffer()));
       auto* b = im2col_output.data<float>();
       auto* c = output.data<float>();
       const Gemm gemm(1.0, 0.0, false, true);
@@ -190,28 +193,34 @@ namespace ctranslate2 {
                                  /*round_before_cast=*/true);
       const Dequantize dequantize_op;
       const auto device = im2col_output.device();
-      cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
+      cpu::parallel_for(0, batch_size * _groups, 1, [&](dim_t begin, dim_t end) {
         StorageView qinput(weight.dtype(), device);
         StorageView qinput_scale(device);
         if (qscale)
           qinput_scale.to(qscale->dtype());
         StorageView qoutput(DataType::INT32, device);
         for (dim_t i = begin; i < end; ++i) {
+          auto group_index = i % _groups;
+          void* w_i = w + (group_index * stridew);
           float* b_i = b + (i * strideb);
           float* c_i = c + (i * stridec);
+          // Create a 2D view of group weights to use in GEMM
+          StorageView aa(weight.dtype(), weight.device());
+          aa.view(w_i, {m, k});
           StorageView bb({n, k}, b_i); // transposed
           StorageView cc({m, n}, c_i);
           if (qscale) {
+            StorageView group_qscale({qscale_stride}, const_cast<float*>(qscale->data<float>()) + group_index * qscale_stride);
             quantize_op(bb, qinput, qinput_scale);
-            gemm(weight_view, qinput, qoutput);
+            gemm(aa, qinput, qoutput);
             dequantize_op(qoutput,
-                          *qscale,
+                          group_qscale,
                           qinput_scale,
                           /*trans_a=*/false,
                           /*trans_b=*/true,
                           cc);
           } else {
-            gemm(weight_view, bb, cc);
+            gemm(aa, bb, cc);
           }
         }
       });
@@ -220,7 +229,7 @@ namespace ctranslate2 {
 
     void Conv1D::im2col_transposed(const StorageView& input, StorageView& output, const dim_t kernel_size) const {
       // input:  batch_size x in_channels x input_length
-      // output: batch_size x output_length x (in_channels * kernel_size)
+      // output: batch_size x groups x output_length x (in_channels_per_group * kernel_size)
       const dim_t batch_size = input.dim(0);
       const dim_t in_channels = input.dim(1);
       const dim_t input_length = input.dim(2);
@@ -228,16 +237,20 @@ namespace ctranslate2 {
       const auto* in = input.data<float>();
       dim_t out_offset = 0;
       const auto in_batch_stride = in_channels * input_length;
+      const dim_t in_channels_per_group = in_channels / _groups;
+      const dim_t in_group_stride = in_channels_per_group * input_length;
       for (dim_t batch_offset = 0; batch_offset < batch_size * in_batch_stride; batch_offset += in_batch_stride) {
-        for (int ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
-          for (dim_t c = batch_offset; c < (batch_offset + in_channels * input_length); c += input_length) {
-            for (int k = 0; k < kernel_size; k++) {
-              // Fill items in [0, input_length) range
-              auto window_i = k + ti;
-              if (0 <= window_i && window_i < input_length) {
-                out[out_offset] = in[window_i + c];
+        for (dim_t group_offset = batch_offset; group_offset < (batch_offset + _groups * in_group_stride); group_offset += in_group_stride) {
+          for (dim_t ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
+            for (dim_t c = group_offset; c < (group_offset + in_channels_per_group * input_length); c += input_length) {
+              for (int k = 0; k < kernel_size; k++) {
+                // Fill items in [0, input_length) range
+                auto window_i = k + ti;
+                if (0 <= window_i && window_i < input_length) {
+                  out[out_offset] = in[window_i + c];
+                }
+                out_offset += 1;
               }
-              out_offset += 1;
             }
           }
         }
diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu
index 6f4d10b39..11ba48a1f 100644
--- a/src/ops/conv1d_gpu.cu
+++ b/src/ops/conv1d_gpu.cu
@@ -28,6 +28,7 @@ namespace ctranslate2 {
       const int input_length = input.dim(2);
       const int output_length = output.dim(2);
       const int out_channels = weight.dim(0);
+      const int in_channels_per_group = weight.dim(1);
       const int kernel_size = weight.dim(2);
 
       cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype());
@@ -45,7 +46,7 @@ namespace ctranslate2 {
       cudnnFilterDescriptor_t weight_desc;
       CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc));
       CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW,
-                                             out_channels, in_channels, 1, kernel_size));
+                                             out_channels, in_channels_per_group, 1, kernel_size));
 
       cudnnConvolutionDescriptor_t conv_desc;
       CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
@@ -57,6 +58,8 @@ namespace ctranslate2 {
                                                   CUDNN_DATA_FLOAT));
 
       CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));
+      if (_groups > 1)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, _groups));
 
       if (data_type == CUDNN_DATA_HALF)
         CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index e0615c052..7d7b376fa 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -1103,6 +1103,84 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {
   expect_storage_eq(output.to_float32(), expected, error);
 }
 
+TEST_P(OpDeviceFPTest, Conv1DGroupNoBias) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+  const StorageView expected({2, 2, 2}, std::vector<float>{
+      -0.475623f, -0.601933f, 0.165541f, 0.050849f, -0.566024f,
+      -0.592437f, 0.121356f, 0.232157f});
+  const StorageView conv_input({2, 4, 4}, std::vector<float>{
+      0.547210f, 0.634821f, 0.571043f, 0.443073f, 0.220554f, 0.478427f,
+      0.836031f, 0.476906f, 0.288942f, 0.393840f, 0.077658f, 0.236493f,
+      0.759209f, 0.826134f, 0.728944f, 0.130438f, 0.355182f, 0.884368f,
+      0.494477f, 0.004999f, 0.306053f, 0.764639f, 0.903179f, 0.440537f,
+      0.040332f, 0.533495f, 0.428653f, 0.311188f, 0.951956f, 0.785873f,
+      0.443364f, 0.065968f});
+  const StorageView conv_weight({2, 2, 3}, std::vector<float>{
+      -0.326986f, -0.378711f, -0.120962f, 0.125665f, -0.312741f, 0.161123f,
+      0.226274f, 0.340959f, -0.127573f, 0.094374f, -0.164143f, 0.054516f});
+  StorageView output(dtype, device);
+  ops::Conv1D(1, 0, 1, 2)(conv_input.to(device).to(dtype),
+                          conv_weight.to(device).to(dtype),
+                          output);
+  EXPECT_EQ(output.dtype(), dtype);
+  expect_storage_eq(output.to_float32(), expected, error);
+}
+
+TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {
+#ifdef CT2_WITH_DNNL
+  GTEST_SKIP() << "Quantized convolution is not implemented for DNNL.";
+#endif
+  const Device device = GetParam().device;
+  if (device != Device::CPU)
+    GTEST_SKIP() << "Grouped quantized convolution is not implemented for CUDA.";
+  const DataType dtype = GetParam().dtype;
+  const float error = std::max(GetParam().error, float(3e-3));
+  const StorageView expected({2, 2, 2}, std::vector<float>{
+      -0.475623f, -0.601933f, 0.165541f, 0.050849f, -0.566024f,
+      -0.592437f, 0.121356f, 0.232157f});
+  const StorageView conv_input({2, 4, 4}, std::vector<float>{
+      0.547210f, 0.634821f, 0.571043f, 0.443073f, 0.220554f, 0.478427f,
+      0.836031f, 0.476906f, 0.288942f, 0.393840f, 0.077658f, 0.236493f,
+      0.759209f, 0.826134f, 0.728944f, 0.130438f, 0.355182f, 0.884368f,
+      0.494477f, 0.004999f, 0.306053f, 0.764639f, 0.903179f, 0.440537f,
+      0.040332f, 0.533495f, 0.428653f, 0.311188f, 0.951956f, 0.785873f,
+      0.443364f, 0.065968f});
+  // These weights correspond to the ones in Conv1DGroupNoBias.
+  // Hence the expected output is the same (up to quantization error).
+  // Therefore we use error = 3e-3.
+  const StorageView conv_weight({2, 2, 3}, std::vector<int8_t>{
+      -110, -127, -41, 42, -105, 54, 84, 127, -48, 35, -61, 20});
+  const StorageView conv_qscale({2}, std::vector<float>{335.34806224, 372.47880244});
+  StorageView output(dtype, device);
+  ops::Conv1D(1, 0, 1, 2)(conv_input.to(device).to(dtype),
+                          conv_weight.to(device),
+                          output,
+                          &conv_qscale);
+  EXPECT_EQ(output.dtype(), dtype);
+  expect_storage_eq(output.to_float32(), expected, error);
+}
+
+TEST_P(OpDeviceFPTest, Conv1DGroup) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+  const StorageView expected({2, 2, 2}, std::vector<float>{
+      0.142335f, 0.103515f, 0.735452f, 0.755268f, 0.109328f, 0.007098f, 0.791004f, 0.537695f});
+  const StorageView conv_input({2, 4, 4}, std::vector<float>{
+      0.769843f, 0.147572f, 0.195656f, 0.823936f, 0.363211f, 0.584773f, 0.315626f, 0.929829f,
+      0.724258f, 0.853388f, 0.756254f, 0.791604f, 0.463644f, 0.285105f, 0.952018f, 0.660709f,
+      0.557387f, 0.147298f, 0.473786f, 0.566577f, 0.255724f, 0.488177f, 0.534283f, 0.678067f,
+      0.760340f, 0.024571f, 0.559195f, 0.978376f, 0.473044f, 0.351244f, 0.824801f, 0.077629f});
+  const StorageView conv_weight({2, 2, 3}, std::vector<float>{
+      0.345985f, -0.071498f, 0.200554f, 0.185144f, -0.015271f, 0.014293f,
+      0.006771f, -0.078667f, -0.065937f, 0.382823f, 0.276695f, 0.352038f});
+  const StorageView conv_bias({2}, std::vector<float>{-0.215535f, 0.256019f});
+  StorageView output(dtype, device);
+  ops::Conv1D(1, 0, 1, 2)(conv_input.to(device).to(dtype),
+                          conv_weight.to(device).to(dtype),
+                          conv_bias.to(device).to(dtype),
+                          output);
+  EXPECT_EQ(output.dtype(), dtype);
+  expect_storage_eq(output.to_float32(), expected, error);
+}
 
 INSTANTIATE_TEST_SUITE_P(CPU, OpDeviceTest, ::testing::Values(Device::CPU));
 INSTANTIATE_TEST_SUITE_P(CPU, OpDeviceFPTest,