Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster SpatialDepthWiseConvolution #481

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

shrubb
Copy link

@shrubb shrubb commented Aug 4, 2017

The current stub implementation is totally impractical; the fastest GPU depthwise conv for Torch was cuDNN's grouped conv with self.groups == self.nInputPlane. Still, even with cuDNN, Google's MobileNets for example would run 2 times slower than ResNet-34.

I've tried lots of lots of lots of methods to efficiently reduce this stuff to cublas routines: using gemmBatched, grouping channels for heavier gemms load etc. Unfortunately, then I've only managed to roughly reach cuDNN' performance in backward pass and make a 1.5x speedup in MobileNet forward pass.

Surprisingly, the fastest option by far turned to be...the super dumb for loop. Here it is (with a bit smarter option for accGradParams though). The forward/backward passes are now at least 45x/8x faster than the original implementation respectively. Default MobileNet's inference enjoys a speedup over cuDNN case of 3.57x on Maxwell and 5.18x on Pascal.

Tested all the output and gradients with large batch size & nInputPlane and various nOutputPlane.

Although the weight shape is (nOutputPlane) x (nInputPlane) x (kH) x (kW) which perfectly corresponds to cuDNN bindings, I didn't like it much since the weight tensor needs to be transposed back and forth when you need almost any kind of matmul/matvec. Don't know if it's critical, but I left it as is just to be safe anyway.

@killeent
Copy link

killeent commented Aug 4, 2017

@shrubb can you post the script you used for benchmarking?

Also cc pytorch/pytorch#1708

@shrubb
Copy link
Author

shrubb commented Aug 4, 2017

@killeent here it is. Hope it's just to check if my benchmarking is sane. The script's dirty, I just use it for rare checks.

I don't pretend to bring in some super-optimal implementation, there's actually little new in this PR. Just a better stub.

For anyone working on this, my previous attempts to employ "smarter" cublas usage are in this branch.

@soumith
Copy link
Member

soumith commented Aug 4, 2017

cc: @ajtulloch @ngimel

@soumith
Copy link
Member

soumith commented Aug 4, 2017

there's also https://github.com/szagoruyko/pyinn/blob/master/pyinn/conv2d_depthwise.py which is based on the Caffe code. it's a specialized conv. I presume you already benchmarked something in this order.

@ajtulloch
Copy link
Contributor

One thing to improve it is to add template specializations for the most common kH/kW/stride/dilation (e.g. IMO it's worth adding a template specialization for 3x3s1 and re-benchmarking mobilenet/shufflenet).

@szagoruyko
Copy link
Member

@soumith pyinn implementation is about the same with what Egor did, simple for loops and sum on grad wrt weight

@shrubb
Copy link
Author

shrubb commented Aug 5, 2017

@ajtulloch already tried this, gives just a negligible improvement:
image

@shrubb
Copy link
Author

shrubb commented Aug 5, 2017

Again, I believe this is NOT practical too. The code for weight gradients is also super simple, could be easily accelerated as well. Benchmarking it for fun.
image

@ajtulloch
Copy link
Contributor

@schrubb how much did you specialize in that benchmark? i.e ideally you'd have a specialization for statically-known kh, kw, stride dilation + separate paths in the kernel for definitely-inbounds/possibly-outbounds?

@shrubb
Copy link
Author

shrubb commented Aug 6, 2017

@ajtulloch hardcoded (kH, kW, padH, padW, strideH, strideW) = (3, 3, 1, 1, 1, 1). Dilation is also fixed (1,1) in both kernels.
I don't see why tracking inbound/outbound pixel will matter, memory reads are mostly contiguous anyway.
To me, all this seems insignificant compared to main arithmetic routines' load.

Anyway, once again, I don't understand what is this all benchmarking and fighting for another 0.001 seconds for. Nobody is going to use this. Everyone's on PyTorch, and NVIDIA is likely to release this kind of conv in cuDNN sooner or later.

@ngimel
Copy link

ngimel commented Aug 6, 2017

Pytorch shares backend with torch, so this could be exposed in pytorch. I don't know how it compares with pyinn, though, in terms of performance. Nvidia won't release it tomorrow, and people have been wanting to use depthwise separable convolutions for years, so anything helps.

@szagoruyko
Copy link
Member

@ngimel looks like cudnn7 supports grouped convolutions, would it be slower than such implementation?

@fmassa
Copy link
Contributor

fmassa commented Aug 7, 2017

@szagoruyko it seems that your pyinn kernels for depthwise convolutions are better than cudnn7 with grouped convolutions. But I'll let @ngimel comment more on that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants