diff --git a/projects_oss/detr/detr/src/gpu/ms_deform_attn_cuda.cu b/projects_oss/detr/detr/src/gpu/ms_deform_attn_cuda.cu index d21b4c8a..d84913ad 100644 --- a/projects_oss/detr/detr/src/gpu/ms_deform_attn_cuda.cu +++ b/projects_oss/detr/detr/src/gpu/ms_deform_attn_cuda.cu @@ -61,7 +61,7 @@ at::Tensor ms_deform_attn_cuda_forward( for (int n = 0; n < batch/im2col_step_; ++n) { auto columns = output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", [&] { ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), @@ -71,7 +71,7 @@ at::Tensor ms_deform_attn_cuda_forward( batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, columns.data()); - })); + }); } output = output.view({batch, num_query, num_heads*channels}); @@ -131,7 +131,7 @@ std::vector ms_deform_attn_cuda_backward( for (int n = 0; n < batch/im2col_step_; ++n) { auto grad_output_g = grad_output_n.select(0, n); - AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", [&] { ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), grad_output_g.data(), value.data() + n * im2col_step_ * per_value_size, @@ -144,7 +144,7 @@ std::vector ms_deform_attn_cuda_backward( grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); - })); + }); } return {