Equivalence to Lagrange polynomial padded convolution #2
Thanks for the insights. I did a similar test before, but the results were different; maybe I did something wrong. Your code is convincing.
I will try to prove this mathematically. So, my paper basically states that Lagrange-based padding can preserve the differential operator associated with a kernel. Because padding is easier than the kernel transformation we proposed, I think this is what we should do in the implementation. This inspired me to think that we could contribute a new padding mode to PyTorch, i.e.,

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='lagrange', device=None, dtype=None)

There are many things to implement so that it works compatibly with the other parameters (such as stride, dilation, and even kernel sizes). Are you happy to join this initiative?
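To make that concrete, here is a rough sketch of how such a 'lagrange' padding mode could be emulated today (this is not an existing PyTorch option, and lagrange_pad_last_dim / conv2d_lagrange_padded are hypothetical names): extrapolate the borders with precomputed Lagrange coefficients, e.g. from the make_lagrange_padding_coefs snippet in this thread, and then call F.conv2d with padding=0.

```python
import torch
import torch.nn.functional as F

def lagrange_pad_last_dim(x, coefs):
    # coefs: (pad, num_predictors) Lagrange extrapolation weights; row i gives the
    # (i+1)-th padded sample as a combination of the last num_predictors samples,
    # ordered from the farthest to the nearest one along the padded dimension.
    coefs = torch.as_tensor(coefs, dtype=x.dtype, device=x.device)
    n = coefs.shape[1]
    right = x[..., -n:] @ coefs.T            # extrapolate past the right/bottom edge
    left = x[..., :n].flip(-1) @ coefs.T     # same weights, mirrored, for the left/top edge
    return torch.cat([left.flip(-1), x, right], dim=-1)

def conv2d_lagrange_padded(x, weight, coefs, bias=None):
    # Pad the width, then the height, then convolve with no built-in padding.
    x = lagrange_pad_last_dim(x, coefs)
    x = lagrange_pad_last_dim(x.transpose(-1, -2), coefs).transpose(-1, -2)
    return F.conv2d(x, weight, bias=bias, padding=0)
```

For an odd k×k kernel, coefs = make_lagrange_padding_coefs((k - 1) // 2, k) would give the usual 'same' output size; stride, dilation, and even kernel sizes would still need to be handled, as noted above.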
I'm happy that you took this well. 😊
Yes, on the condition that the underlying continuous-domain input is locally equal to the Lagrange polynomial.

What I also found when computing the transformation matrices is that the differential operator is preserved even if the original location chosen for it is not the center of the original window; it is just the amount of spatial shift that defines the transformation. That's great, because I was initially worried about what happens in the kernel transformation if the network has come up with a useful differential operator with a non-zero spatial offset. This also opens the door to kernel transformations that do sub-pixel shifts of a differential operator implemented by an existing convolution kernel, which is probably useful for something; there, a kernel transformation approach would be better than shifting the whole image. I'm repeating myself, but a horizontal/vertical decomposition of the transformation would make things computationally cheaper.

About helping with a
I found, empirically, a quick way to compute the Lagrange polynomial extrapolation padding coefficients:

```python
import numpy as np
from scipy.special import binom

def make_lagrange_padding_coefs(max_num_padding, num_predictors):
    """Return a (max_num_padding, num_predictors) integer matrix. Row i holds the
    weights of the last num_predictors samples (ordered from the farthest to the
    nearest one) that extrapolate the (i+1)-th padded sample with a polynomial
    of degree num_predictors - 1."""
    for k in range(num_predictors):
        if k == 0:
            cumsumconsts = np.array([[1]])
        else:
            b = np.array(binom(k, np.arange(k + 1)), dtype=int)
            b = b * np.expand_dims(b, 1)
            b[1:, 1:] += cumsumconsts
            cumsumconsts = b
    # The code below could be included in the above for-loop to compute
    # coefs for many num_predictors in one go.
    c = np.tile(cumsumconsts[0], (max(max_num_padding - k, 0), 1))
    for i in range(1, k + 1):
        c = np.vstack([cumsumconsts[i], c])
    if c.shape[0] > max_num_padding:
        c = c[:max_num_padding, :]
    c = np.cumsum(c, axis=0)
    c = c * (-1) ** np.flip(np.arange(k + 1))
    return c
```
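As a quick sanity check of that convention (each row applies to the last num_predictors samples, farthest first, and row i gives the (i+1)-th padded sample), the weights should reproduce ordinary polynomial extrapolation via np.polyfit; a small assumed usage example:

```python
coefs = make_lagrange_padding_coefs(max_num_padding=2, num_predictors=3)
print(coefs)
# [[ 1 -3  3]
#  [ 3 -8  6]]

x = np.array([2.0, -1.0, 5.0, 4.0, 7.0])
pads = coefs @ x[-3:]                        # two extrapolated padding samples
p = np.polyfit(np.arange(3), x[-3:], deg=2)  # quadratic through the last three samples
print(np.allclose(pads, np.polyval(p, [3, 4])))  # True
```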
The 2D differential convolution method seems to be equivalent to using Lagrange polynomial padding of degree (kernel size - 1). Here is a minimal example supporting kernel sizes 3, 5, and 7, using vertical Lagrange polynomial padding followed by horizontal Lagrange polynomial padding:
Output:
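A rough sketch of the padding half of such a comparison (lagrange_pad_2d is a hypothetical helper reusing make_lagrange_padding_coefs from above; the diff conv side comes from the paper's kernel transformation and is not sketched here):

```python
from scipy.signal import convolve2d

def lagrange_pad_2d(img, kernel_size):
    # Pad each border by (kernel_size - 1) // 2 samples of Lagrange extrapolation
    # of degree kernel_size - 1 (kernel_size predictors per border).
    pad = (kernel_size - 1) // 2
    coefs = make_lagrange_padding_coefs(pad, kernel_size)
    # Vertical padding: extrapolate rows beyond the top and bottom edges.
    top = np.flipud(coefs @ np.flipud(img[:kernel_size, :]))
    bottom = coefs @ img[-kernel_size:, :]
    img = np.vstack([top, img, bottom])
    # Horizontal padding: the same operation on the transposed image.
    left = np.flipud(coefs @ np.flipud(img.T[:kernel_size, :]))
    right = coefs @ img.T[-kernel_size:, :]
    return np.vstack([left, img.T, right]).T

img = np.random.rand(8, 8)
kernel = np.random.rand(3, 3)
out = convolve2d(lagrange_pad_2d(img, 3), kernel, mode='valid')
print(out.shape)  # (8, 8): 'same'-sized output without zero padding
```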
With stride and dilation the equivalence is not as clear; there, I think the equivalence would appear with strided/dilated padding.
In any case, I find the paper, including the theoretical treatment, valuable and thought-provoking. The fact that the kernel can be modified instead of doing padding is also interesting from the point of view of potentially reducing memory accesses, and maybe from the gradient-computation point of view. I would suggest looking into the equivalence and mentioning it in a v2 of the paper.
The diff conv approach could also be done in a horizontally/vertically separated way, similar to what was done in my Lagrange polynomial padding code.