feat: grouped conv1d (#1749)
* feat: implement grouped convolution for CPU

* feat: implement grouped Conv1D for DNNL

* feat: implement grouped Conv1D for CUDA
ebraraktas authored Aug 13, 2024
1 parent a386cbd commit 1000086
Showing 7 changed files with 128 additions and 30 deletions.
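
A minimal usage sketch of the new groups parameter, mirroring the call pattern of the tests added in this commit (the helper function name and example shapes are illustrative, not part of the change):

#include <ctranslate2/ops/conv1d.h>
#include <ctranslate2/storage_view.h>

using namespace ctranslate2;

// Sketch only. Shapes follow the grouped-conv convention used in this commit:
//   input  : (batch, in_channels, time)
//   weight : (out_channels, in_channels / groups, kernel_size)
void grouped_conv1d_example(const StorageView& input,   // e.g. {2, 4, 4}
                            const StorageView& weight,  // e.g. {2, 2, 3}
                            StorageView& output) {
  // stride = 1, padding = 0, dilation = 1, groups = 2:
  // each output channel only sees in_channels / groups input channels.
  ops::Conv1D(/*stride=*/1, /*padding=*/0, /*dilation=*/1, /*groups=*/2)(input, weight, output);
  // output: {2, 2, 2} for the example shapes above
}
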
3 changes: 2 additions & 1 deletion include/ctranslate2/layers/common.h
@@ -174,7 +174,8 @@ namespace ctranslate2 {
const std::string& scope,
dim_t stride = 1,
dim_t padding = 0,
dim_t dilation = 1);
dim_t dilation = 1,
dim_t groups = 1);
DataType output_type() const override;
dim_t output_size() const override;
dim_t input_size() const;
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/conv1d.h
@@ -8,7 +8,7 @@ namespace ctranslate2 {

class Conv1D : public Op {
public:
Conv1D(dim_t stride = 1, dim_t padding = 0, dim_t dilation = 1);
Conv1D(dim_t stride = 1, dim_t padding = 0, dim_t dilation = 1, dim_t groups=1);

void operator()(const StorageView& input,
const StorageView& weight,
@@ -25,6 +25,7 @@ namespace ctranslate2 {
dim_t _stride;
dim_t _padding;
dim_t _dilation;
dim_t _groups;

void operator()(const StorageView& input,
const StorageView& weight,
5 changes: 3 additions & 2 deletions src/layers/common.cc
@@ -467,8 +467,9 @@ namespace ctranslate2 {
const std::string& scope,
dim_t stride,
dim_t padding,
dim_t dilation)
: _conv_op(stride, padding, dilation)
dim_t dilation,
dim_t groups)
: _conv_op(stride, padding, dilation, groups)
, _weight(model.get_variable(scope + "/weight"))
, _bias(model.get_variable_if_exists(scope + "/bias"))
, _qscale(model.get_variable_if_exists(scope + "/weight_scale")) {
3 changes: 2 additions & 1 deletion src/ops/conv1d.cc
@@ -5,10 +5,11 @@
namespace ctranslate2 {
namespace ops {

Conv1D::Conv1D(dim_t stride, dim_t padding, dim_t dilation)
Conv1D::Conv1D(dim_t stride, dim_t padding, dim_t dilation, dim_t groups)
: _stride(stride)
, _padding(padding)
, _dilation(dilation)
, _groups(groups)
{
}

61 changes: 37 additions & 24 deletions src/ops/conv1d_cpu.cc
@@ -19,7 +19,7 @@ namespace ctranslate2 {

dnnl::memory::dims input_dims(input.shape().begin(), input.shape().end());
dnnl::memory::dims output_dims(output.shape().begin(), output.shape().end());
dnnl::memory::dims weight_dims(weight.shape().begin(), weight.shape().end());
dnnl::memory::dims weight_dims{_groups, weight.dim(0) / _groups, weight.dim(1), weight.dim(2)};

using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type;
@@ -32,7 +32,7 @@ namespace ctranslate2 {
const_cast<void*>(input.buffer()));
dnnl::memory output_mem({output_dims, dt::f32, tag::ncw}, engine,
output.buffer());
dnnl::memory weight_mem({weight_dims, dt::f32, tag::oiw}, engine,
dnnl::memory weight_mem({weight_dims, dt::f32, tag::goiw}, engine,
const_cast<void*>(weight.buffer()));

dnnl::memory::dims stride{_stride};
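
For the DNNL path, the 3D Conv1D weight of shape (out_channels, in_channels / groups, kernel_size) is handed to oneDNN as a 4D grouped tensor with format tag goiw instead of oiw. A sketch of that reinterpretation using oneDNN's C++ API (grouped_weight_desc is an illustrative helper, not part of the commit):

#include <dnnl.hpp>

// Illustrative only: view a weight stored as (O, I/G, K) as the grouped
// layout (G, O/G, I/G, K) that oneDNN expects for grouped convolution.
dnnl::memory::desc grouped_weight_desc(const dnnl::memory::dims& w,  // {O, I/G, K}
                                       dnnl::memory::dim groups) {
  const dnnl::memory::dims grouped{groups, w[0] / groups, w[1], w[2]};
  return dnnl::memory::desc(grouped, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::goiw);
}
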
@@ -160,9 +160,10 @@ namespace ctranslate2 {
const dim_t out_channels = weight.dim(0);
const dim_t kernel_size = weight.dim(2);
const dim_t output_length = output.dim(2);
const dim_t in_channels_per_group = in_channels / _groups;

// Create im2col_output tensor.
// im2col_output shape is (batch_size, out_length, in_channels * kernel_size).
// im2col_output shape is (batch_size, groups, out_length, in_channels_per_group * kernel_size).
// This is necessary for quantization:
// * we need to run GEMM as (weight x im2col_output) to avoid extra copies
// * input (RHS) must be quantized along columns, to dequantize later
@@ -171,17 +172,19 @@
// * We can use qinput_scale generated from rows of this im2col_output, as they correspond
// to columns of the multiplied shape (because of transpose).
// we can use same matrix for FLOAT32 computation, too.
StorageView im2col_output({batch_size, output_length, in_channels * kernel_size}, 0.0f, weight.device());
StorageView im2col_output({batch_size, _groups, output_length, in_channels_per_group * kernel_size}, 0.0f, weight.device());
im2col_transposed(input, im2col_output, kernel_size);
// Create a 2D view of weight to use in GEMM
StorageView weight_view(weight.dtype(), weight.device());
weight_view.view(const_cast<void*>(weight.buffer()), {weight.dim(0), in_channels * kernel_size});

const dim_t m = out_channels;
// out: bs x (group * out_per_group) x out_len
const dim_t m = out_channels / _groups;
const dim_t n = output_length;
const dim_t k = in_channels * kernel_size;
const dim_t k = in_channels_per_group * kernel_size;
// stridew in bytes to handle quantized weights
const dim_t stridew = out_channels / _groups * in_channels_per_group * kernel_size * weight.item_size();
const dim_t strideb = k * output_length;
const dim_t stridec = out_channels * output_length;
const dim_t stridec = m * output_length;
const dim_t qscale_stride = qscale ? qscale->dim(0) / _groups : 0;
// Create byte pointer to the weights, and we will use group weight slice in the loop below
auto* w = static_cast<int8_t*>(const_cast<void*>(weight.buffer()));
auto* b = im2col_output.data<float>();
auto* c = output.data<float>();
const Gemm gemm(1.0, 0.0, false, true);
@@ -190,54 +193,64 @@
/*round_before_cast=*/true);
const Dequantize dequantize_op;
const auto device = im2col_output.device();
cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
cpu::parallel_for(0, batch_size * _groups, 1, [&](dim_t begin, dim_t end) {
StorageView qinput(weight.dtype(), device);
StorageView qinput_scale(device);
if (qscale)
qinput_scale.to(qscale->dtype());
StorageView qoutput(DataType::INT32, device);
for (dim_t i = begin; i < end; ++i) {
auto group_index = i % _groups;
void* w_i = w + (group_index * stridew);
float* b_i = b + (i * strideb);
float* c_i = c + (i * stridec);
// Create a 2D view of group weights to use in GEMM
StorageView aa(weight.dtype(), weight.device());
aa.view(w_i, {m, k});
StorageView bb({n, k}, b_i); // transposed
StorageView cc({m, n}, c_i);

if (qscale) {
StorageView group_qscale({qscale_stride}, const_cast<float *>(qscale->data<float>()) + group_index * qscale_stride);
quantize_op(bb, qinput, qinput_scale);
gemm(weight_view, qinput, qoutput);
gemm(aa, qinput, qoutput);
dequantize_op(qoutput,
*qscale,
group_qscale,
qinput_scale,
/*trans_a=*/false,
/*trans_b=*/true,
cc);
} else {
gemm(weight_view, bb, cc);
gemm(aa, bb, cc);
}
}
});
}

void Conv1D::im2col_transposed(const StorageView& input, StorageView& output, const dim_t kernel_size) const {
// input: batch_size x in_channels x input_length
// output: batch_size x output_length x (in_channels * kernel_size)
// output: batch_size x groups x output_length x (in_channels_per_group * kernel_size)
const dim_t batch_size = input.dim(0);
const dim_t in_channels = input.dim(1);
const dim_t input_length = input.dim(2);
auto* out = output.data <float>();
const auto* in = input.data <float>();
dim_t out_offset = 0;
const auto in_batch_stride = in_channels * input_length;
const dim_t in_channels_per_group = in_channels / _groups;
const dim_t in_group_stride = in_channels_per_group * input_length;
for (dim_t batch_offset = 0; batch_offset < batch_size * in_batch_stride; batch_offset += in_batch_stride) {
for (int ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
for (dim_t c = batch_offset; c < (batch_offset + in_channels * input_length); c += input_length) {
for (int k = 0; k < kernel_size; k++) {
// Fill items in [0, input_length) range
auto window_i = k + ti;
if (0 <= window_i && window_i < input_length) {
out[out_offset] = in[window_i + c];
for (dim_t group_offset = batch_offset; group_offset < (batch_offset + _groups * in_group_stride); group_offset += in_group_stride) {
for (dim_t ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
for (dim_t c = group_offset; c < (group_offset + in_channels_per_group * input_length); c += input_length) {
for (int k = 0; k < kernel_size; k++) {
// Fill items in [0, input_length) range
auto window_i = k + ti;
if (0 <= window_i && window_i < input_length) {
out[out_offset] = in[window_i + c];
}
out_offset += 1;
}
out_offset += 1;
}
}
}
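
For reference, each (batch, group) pair in the non-DNNL CPU path above is handled by one GEMM, C (m x n) = A (m x k) * B^T, with m = out_channels / groups, k = in_channels_per_group * kernel_size and n = output_length. A sketch of the corresponding offset arithmetic, restated outside the diff (helper and field names are illustrative):

#include <cstdint>

struct GroupSliceOffsets {
  int64_t weight_bytes;  // byte offset of the group's weight slice  (group * stridew)
  int64_t im2col_elems;  // element offset into im2col_output        (i * strideb)
  int64_t output_elems;  // element offset into the output tensor    (i * stridec)
};

GroupSliceOffsets group_slice_offsets(int64_t batch, int64_t group, int64_t groups,
                                      int64_t out_channels, int64_t in_channels_per_group,
                                      int64_t kernel_size, int64_t output_length,
                                      int64_t weight_item_size) {
  const int64_t m = out_channels / groups;
  const int64_t k = in_channels_per_group * kernel_size;
  const int64_t i = batch * groups + group;  // flattened loop index, as in the parallel_for
  return {group * m * k * weight_item_size,
          i * k * output_length,
          i * m * output_length};
}
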
5 changes: 4 additions & 1 deletion src/ops/conv1d_gpu.cu
@@ -28,6 +28,7 @@ namespace ctranslate2 {
const int input_length = input.dim(2);
const int output_length = output.dim(2);
const int out_channels = weight.dim(0);
const int in_channels_per_group = weight.dim(1);
const int kernel_size = weight.dim(2);

cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype());
@@ -45,7 +46,7 @@ namespace ctranslate2 {
cudnnFilterDescriptor_t weight_desc;
CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW,
out_channels, in_channels, 1, kernel_size));
out_channels, in_channels_per_group, 1, kernel_size));

cudnnConvolutionDescriptor_t conv_desc;
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
@@ -57,6 +58,8 @@ namespace ctranslate2 {
CUDNN_DATA_FLOAT));

CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));
if (_groups > 1)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, _groups));
if (data_type == CUDNN_DATA_HALF)
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));

78 changes: 78 additions & 0 deletions tests/ops_test.cc
@@ -1103,6 +1103,84 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {
expect_storage_eq(output.to_float32(), expected, error);
}

TEST_P(OpDeviceFPTest, Conv1DGroupNoBias) {
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
const StorageView expected({2, 2, 2}, std::vector<float>{
-0.475623f, -0.601933f, 0.165541f, 0.050849f, -0.566024f,
-0.592437f, 0.121356f, 0.232157f});
const StorageView conv_input({2, 4, 4}, std::vector<float>{
0.547210f, 0.634821f, 0.571043f, 0.443073f, 0.220554f, 0.478427f,
0.836031f, 0.476906f, 0.288942f, 0.393840f, 0.077658f, 0.236493f,
0.759209f, 0.826134f, 0.728944f, 0.130438f, 0.355182f, 0.884368f,
0.494477f, 0.004999f, 0.306053f, 0.764639f, 0.903179f, 0.440537f,
0.040332f, 0.533495f, 0.428653f, 0.311188f, 0.951956f, 0.785873f,
0.443364f, 0.065968f});
const StorageView conv_weight({2, 2, 3}, std::vector<float>{
-0.326986f, -0.378711f, -0.120962f, 0.125665f, -0.312741f, 0.161123f,
0.226274f, 0.340959f, -0.127573f, 0.094374f, -0.164143f, 0.054516f});
StorageView output(dtype, device);
ops::Conv1D(1, 0, 1, 2)(conv_input.to(device).to(dtype),
conv_weight.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
expect_storage_eq(output.to_float32(), expected, error);
}
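
As a cross-check of the semantics this test exercises, a naive scalar reference of grouped 1D convolution (a sketch, not the library's implementation); with the inputs above and groups = 2 it reproduces, for example, the first expected value of about -0.475623:

#include <cstdint>
#include <vector>

// Reference semantics only: input (batch, in_channels, in_len), weight
// (out_channels, in_channels / groups, kernel), all row-major; zero padding.
std::vector<float> conv1d_reference(const std::vector<float>& input,
                                    const std::vector<float>& weight,
                                    int64_t batch, int64_t in_channels, int64_t in_len,
                                    int64_t out_channels, int64_t kernel,
                                    int64_t stride, int64_t padding,
                                    int64_t dilation, int64_t groups) {
  const int64_t in_per_group = in_channels / groups;
  const int64_t out_per_group = out_channels / groups;
  const int64_t out_len = (in_len + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
  std::vector<float> output(batch * out_channels * out_len, 0.f);
  for (int64_t b = 0; b < batch; ++b)
    for (int64_t o = 0; o < out_channels; ++o) {
      const int64_t g = o / out_per_group;  // group owning this output channel
      for (int64_t t = 0; t < out_len; ++t) {
        float acc = 0.f;
        for (int64_t c = 0; c < in_per_group; ++c)      // only group g's input channels
          for (int64_t k = 0; k < kernel; ++k) {
            const int64_t pos = t * stride + k * dilation - padding;
            if (pos < 0 || pos >= in_len)
              continue;                                 // zero padding
            acc += input[(b * in_channels + g * in_per_group + c) * in_len + pos]
                 * weight[(o * in_per_group + c) * kernel + k];
          }
        output[(b * out_channels + o) * out_len + t] = acc;
      }
    }
  return output;
}
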

TEST_P(OpDeviceFPTest, Conv1DGroupNoBiasQuantized) {
#ifdef CT2_WITH_DNNL
GTEST_SKIP() << "Quantized convolution is not implemented for DNNL.";
#endif
const Device device = GetParam().device;
if (device != Device::CPU)
GTEST_SKIP() << "Grouped quantized convolution is not implemented for CUDA.";
const DataType dtype = GetParam().dtype;
const float error = std::max(GetParam().error, float(3e-3));
const StorageView expected({2, 2, 2}, std::vector<float>{
-0.475623f, -0.601933f, 0.165541f, 0.050849f, -0.566024f,
-0.592437f, 0.121356f, 0.232157f});
const StorageView conv_input({2, 4, 4}, std::vector<float>{
0.547210f, 0.634821f, 0.571043f, 0.443073f, 0.220554f, 0.478427f,
0.836031f, 0.476906f, 0.288942f, 0.393840f, 0.077658f, 0.236493f,
0.759209f, 0.826134f, 0.728944f, 0.130438f, 0.355182f, 0.884368f,
0.494477f, 0.004999f, 0.306053f, 0.764639f, 0.903179f, 0.440537f,
0.040332f, 0.533495f, 0.428653f, 0.311188f, 0.951956f, 0.785873f,
0.443364f, 0.065968f});
// These weights correspond to the ones in Conv1DGroupNoBias
// Hence expected output is same (with quantization error)
// Therefore we use error = 3e-3
const StorageView conv_weight({2, 2, 3}, std::vector<int8_t>{
-110, -127, -41, 42, -105, 54, 84, 127, -48, 35, -61, 20});
const StorageView conv_qscale({2}, std::vector<float> {335.34806224, 372.47880244});
StorageView output(dtype, device);
ops::Conv1D(1, 0, 1, 2)(conv_input.to(device).to(dtype),
conv_weight.to(device),
output,
&conv_qscale);
EXPECT_EQ(output.dtype(), dtype);
expect_storage_eq(output.to_float32(), expected, error);
}
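
The int8 weights and the two scales above appear to follow symmetric per-output-channel quantization of the Conv1DGroupNoBias weights, i.e. scale_o = 127 / max|w_o| and q = round(w * scale_o). A quick sanity check of the first entries (illustrative, not part of the test suite):

#include <cmath>
#include <cstdio>

int main() {
  // Output channel 0 of Conv1DGroupNoBias: largest |w| is 0.378711.
  const float scale0 = 127.f / 0.378711f;                             // ~335.35 (test uses 335.34806224)
  const int q0 = static_cast<int>(std::lround(-0.326986f * scale0));  // -> -110, the first int8 weight
  std::printf("scale0=%.5f q0=%d\n", scale0, q0);
  return 0;
}
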

TEST_P(OpDeviceFPTest, Conv1DGroup) {
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
const StorageView expected({2, 2, 2}, std::vector<float>{
0.142335f, 0.103515f, 0.735452f, 0.755268f, 0.109328f, 0.007098f, 0.791004f, 0.537695f});
const StorageView conv_input({2, 4, 4}, std::vector<float>{
0.769843f, 0.147572f, 0.195656f, 0.823936f, 0.363211f, 0.584773f, 0.315626f, 0.929829f, 0.724258f, 0.853388f, 0.756254f, 0.791604f, 0.463644f, 0.285105f, 0.952018f, 0.660709f, 0.557387f, 0.147298f, 0.473786f, 0.566577f, 0.255724f, 0.488177f, 0.534283f, 0.678067f, 0.760340f, 0.024571f, 0.559195f, 0.978376f, 0.473044f, 0.351244f, 0.824801f, 0.077629f});
const StorageView conv_weight({2, 2, 3}, std::vector<float>{
0.345985f, -0.071498f, 0.200554f, 0.185144f, -0.015271f, 0.014293f, 0.006771f, -0.078667f, -0.065937f, 0.382823f, 0.276695f, 0.352038f});
const StorageView conv_bias({2}, std::vector<float>{-0.215535f, 0.256019f});
StorageView output(dtype, device);
ops::Conv1D(1, 0, 1, 2)(conv_input.to(device).to(dtype),
conv_weight.to(device).to(dtype),
conv_bias.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
expect_storage_eq(output.to_float32(), expected, error);
}

INSTANTIATE_TEST_SUITE_P(CPU, OpDeviceTest, ::testing::Values(Device::CPU));
INSTANTIATE_TEST_SUITE_P(CPU, OpDeviceFPTest,