Equivalence to Lagrange polynomial padded convolution #2

Open
OlliNiemitalo opened this issue Feb 1, 2024 · 4 comments

OlliNiemitalo commented Feb 1, 2024

The 2D differential convolution method seems to be equivalent to using Lagrange polynomial padding of degree (kernel size - 1). Here is a minimal example supporting kernel sizes 3, 5, and 7, using vertical Lagrange polynomial padding followed by horizontal Lagrange polynomial padding:

import numpy as np
import torch
import torch.nn.functional as F
from experiments.exp_filtering_methods import conv2d_diff

# Calculated using https://github.com/hamk-uas/DifferentialConv2d/blob/main/diff_conv2d/lagrange_constants/gen_lagrange_padding_coefs.py
lagrange_padding_coefs = [torch.tensor(c, dtype=torch.float32) for c in [
    [[1.0, -3.0, 3.0]],
    [[1.0, -5.0, 10.0, -10.0, 5.0],
     [5.0, -24.0, 45.0, -40.0, 15.0]],
    [[1.0, -7.0, 21.0, -35.0, 35.0, -21.0, 7.0],
     [7.0, -48.0, 140.0, -224.0, 210.0, -112.0, 28.0],
     [28.0, -189.0, 540.0, -840.0, 756.0, -378.0, 84.0]]
]]

def conv2d_lagrange_extrap(image, kernel, returns_padded=False):
    """conv2d with Lagrange polynomial extrapolation padding (kernel sizes 3, 5, 7)."""
    # Zero-pad first; the zeros are then overwritten by the extrapolated values.
    k_size = kernel.shape[2]
    p = k_size // 2
    padded = F.pad(image, pad=(p, p, p, p))
    coefs = lagrange_padding_coefs[k_size//2 - 1]

    # Vertical extrapolation from the original image rows.
    # The '...' keeps the batch/channel dims, so this also works for sizes > 1.
    # top (flipped coefficients: extrapolate upward)
    padded[:, :, :p, p:-p] = torch.einsum(
        '...yx,sy->...sx', image[:, :, :k_size, :], coefs.flip(dims=(0, 1))
    )
    # bottom
    padded[:, :, -p:, p:-p] = torch.einsum(
        '...yx,sy->...sx', image[:, :, -k_size:, :], coefs
    )
    # Horizontal extrapolation from the already padded columns
    # ("pad also the paddings of the already done dimension").
    # left (flipped coefficients: extrapolate leftward)
    padded[:, :, :, :p] = torch.einsum(
        '...yx,sx->...ys', padded[:, :, :, p:p+k_size], coefs.flip(dims=(0, 1))
    )
    # right
    padded[:, :, :, -p:] = torch.einsum(
        '...yx,sx->...ys', padded[:, :, :, -p-k_size:-p], coefs
    )
    
    if returns_padded:
        return F.conv2d(padded, kernel), padded
    return F.conv2d(padded, kernel)


M, N = 5, 5

torch.manual_seed(0)
image = torch.round(torch.randn((1, 1, M, M))*5)
kernel = torch.round(torch.randn((1, 1, N, N))*5)

np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

print("Random integer image:")
print(image.numpy())

print("\nRandom integer convolution kernel:")
print(kernel.numpy())

print("\nLagrange polynomial padding:")
lagrange_padded_convolved, lagrange_padded = conv2d_lagrange_extrap(image, kernel, returns_padded=True)
print(lagrange_padded.numpy())
print("Lagrange polynomial padded, then convolved:")
print(lagrange_padded_convolved.numpy())

diff_convolved = conv2d_diff(image, kernel)
print("\nDiff convolved:")
print(diff_convolved.numpy())

print("\nDifference between Lagrange polynomial padded convolved and diff convolved:")
print((diff_convolved - lagrange_padded_convolved).numpy())

Output:

Random integer image:
[[[[ -6.  -6.  -1.  -2.   4.]
   [  3.  -2. -11.   2.  -1.]
   [  7.   1.   1.   4.  -1.]
   [ -1.  -3.   6.  10.   0.]
   [  3.  -2.  -4. -12.  -1.]]]]

Random integer convolution kernel:
[[[[-4.  3.  1. -1. -3.]
   [ 5.  2. -3.  0.  6.]
   [ 6. -7. 13. -2.  2.]
   [ 8. 10. -2. -2. -5.]
   [ 6. -1.  4.  2. -4.]]]]

Lagrange polynomial padding:
[[[[17820.  4010.   144.    97.   306.  -230.    50.  5269. 22112.]
   [ 4738.  1049.    18.    13.    81.   -52.    19.  1378.  5788.]
   [  173.    34.    -6.    -6.    -1.    -2.     4.    54.   209.]
   [ -423.   -86.     3.    -2.   -11.     2.    -1.  -122.  -527.]
   [    9.    14.     7.     1.     1.     4.    -1.   -33.  -119.]
   [  135.    35.    -1.    -3.     6.    10.     0.   -26.   -63.]
   [  228.    54.     3.    -2.    -4.   -12.    -1.    88.   348.]
   [ 2168.   524.    74.    34.   -16.  -132.    -6.  1034.  4024.]
   [ 9442.  2274.   298.   153.     4.  -458.   -16.  4073. 16078.]]]]
Lagrange polynomial padded, then convolved:
[[[[-37294. -10606.    177.  -6228. -34906.]
   [-16081.  -4098.     13.  -3570. -17936.]
   [  -156.     62.    -15.  -1194.  -4813.]
   [ 17394.   4883.    220.  -4767. -15136.]
   [ 80126.  21810.   2005. -22072. -78736.]]]]

Diff convolved:
[[[[-37294. -10606.    177.  -6228. -34906.]
   [-16081.  -4098.     13.  -3570. -17936.]
   [  -156.     62.    -15.  -1194.  -4813.]
   [ 17394.   4883.    220.  -4767. -15136.]
   [ 80126.  21810.   2005. -22072. -78736.]]]]

Difference between Lagrange polynomial padded convolved and diff convolved:
[[[[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]]]

With stride and dilation the equivalence is not as clear; I think it would then appear with strided/dilated padding.

In any case, I find the paper, including the theoretical treatment, valuable and thought-provoking. The fact that the kernel can be modified instead of doing padding is also interesting, both for potentially reducing memory accesses and maybe from the gradient computation point of view. I would suggest looking into the equivalence and mentioning it in a v2 of the paper.

The diff conv approach could also be done in a horizontally and vertically separated way, similar to what is done in my Lagrange polynomial padding code above.

kuangdai commented Feb 2, 2024

Thanks for the insights. I did a similar test earlier but got different results; maybe I did something wrong. Your code is convincing.

kuangdai commented Feb 2, 2024

I will try to prove this mathematically.

So, my paper basically states that Lagrange-based padding can preserve the differential operator associated with a kernel. Because padding is easier to implement than the kernel transformation we proposed, I think padding is what we should use in an implementation.

This suggests that we could contribute a new padding mode to PyTorch, i.e.,

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, 
                dilation=1, groups=1, bias=True, padding_mode='lagrange', device=None, dtype=None)

There is a lot to implement so that it works compatibly with the other options (such as stride, dilation, and even kernel sizes). Would you be happy to join this initiative?
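
Before such a mode exists upstream, something like the following wrapper could emulate it. This is only a rough sketch (the class name and the simplistic weight initialization are placeholders): it reuses your conv2d_lagrange_extrap above and therefore handles only odd kernel sizes 3/5/7 with stride 1 and dilation 1.

import torch
import torch.nn as nn

class LagrangeConv2d(nn.Module):
    """Rough sketch of a conv layer with Lagrange extrapolation padding,
    reusing conv2d_lagrange_extrap from the first comment."""

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        # Toy initialization, just so the module is usable as-is.
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
            / (in_channels * kernel_size ** 2) ** 0.5)
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    def forward(self, x):
        out = conv2d_lagrange_extrap(x, self.weight)  # pads, then convolves
        if self.bias is not None:
            out = out + self.bias.view(1, -1, 1, 1)
        return out

A proper contribution to torch.nn.Conv2d would of course also have to handle stride, dilation, groups, even kernel sizes and asymmetric padding.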

OlliNiemitalo commented Feb 2, 2024

I'm happy that you took this well. 😊

So, my paper basically states that Lagrange-based padding can preserve the differential operator associated with a kernel.

Yes, on the condition that the underlying continuous-domain input is locally equal to the Lagrange polynomial.

What I also found when computing the transformation matrices is that the differential operator is preserved even if its origin is chosen somewhere other than the center of the original window; it is only the amount of spatial shift that defines the transformation. That's great, because I was initially worried about what happens in the kernel transformation if the network has come up with a useful differential operator with a non-zero spatial offset. It also opens the door to kernel transformations that apply sub-pixel shifts to a differential operator implemented by an existing convolution kernel, which is probably useful for something; there a kernel transformation approach would be better than shifting the whole image.

I'm repeating myself, but a horizontal/vertical decomposition of the transformation would make things computationally cheaper: transformation(kernel, vertical_shift, horizontal_shift) = vertical_transformation(horizontal_transformation(kernel, horizontal_shift), vertical_shift).
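
To illustrate the separability and the sub-pixel shift idea (a rough sketch only, not the exact kernel transformation of the paper; lagrange_shift_matrix and transform_kernel are my own names):

import numpy as np

def lagrange_shift_matrix(k, shift):
    # S[i, j] = j-th Lagrange basis polynomial through the points 0..k-1,
    # evaluated at the shifted position i + shift.
    S = np.empty((k, k))
    for i in range(k):
        for j in range(k):
            x = i + shift
            S[i, j] = np.prod([(x - m) / (j - m) for m in range(k) if m != j])
    return S

def transform_kernel(kernel, vertical_shift, horizontal_shift):
    Sv = lagrange_shift_matrix(kernel.shape[0], vertical_shift)
    Sh = lagrange_shift_matrix(kernel.shape[1], horizontal_shift)
    # Horizontal transformation followed by vertical transformation;
    # by associativity the order does not matter.
    return Sv @ (kernel @ Sh.T)

# Zero shift gives back the original kernel:
K = np.arange(9.0).reshape(3, 3)
assert np.allclose(transform_kernel(K, 0.0, 0.0), K)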

About helping with a Lagrange padding mode implementation: yes, I can do something. But I don't have experience with PyTorch beyond the code above, so maybe I should not write the PyTorch code directly. What I have already done:

  • Provided precalculated coefficients that could be used in the implementation or for verifying any fast direct formula. My insights so far: the coefficients blow up numerically with increasing Lagrange polynomial degree; degree 26 (padding size 13) is the maximum for which float64 can hold every coefficient exactly. It would potentially be useful to allow limiting the polynomial degree, say to always do a large replication (degree 0) or linear (degree 1) padding, but I don't know whether PyTorch allows numerical parameters for a padding mode. Also, when the image size is less than polynomial degree + 1, the polynomial degree should probably be reduced automatically (see the sketch after this list). That's why I have provided coefficients for the same amount of padding for each Lagrange polynomial degree. It might make sense to allow even larger paddings for smaller degrees, because those are more stable numerically.
  • Provided the above example code, which implements a 2D version of the multidimensional padding decomposition logic: "pad also the paddings of the already done dimensions". The code doesn't have the coefficients for even kernel sizes, but those are in the larger precalculated list; using them requires changing the indexing into lagrange_padding_coefs in the code.
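
A sketch of the degree-reduction logic I have in mind (the function name and parameters are placeholders, not anything existing in PyTorch):

def effective_num_predictors(k_size, image_size, max_degree=None):
    # Default: polynomial degree = kernel size - 1, i.e. k_size predictor samples.
    n = k_size
    if max_degree is not None:
        # Optional cap on the polynomial degree (degree 0 = replication, 1 = linear).
        n = min(n, max_degree + 1)
    # Never use more predictors than there are rows/columns in the image.
    return max(min(n, image_size), 1)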

olli4 commented Oct 3, 2024

I found, empirically, a quick way to compute the Lagrange polynomial extrapolation padding coefficients:

import numpy as np
from scipy.special import binom

def make_lagrange_padding_coefs(max_num_padding, num_predictors):
  """Quick empirical construction of Lagrange polynomial extrapolation padding coefficients.

  Returns a (max_num_padding, num_predictors) integer array whose row s holds the
  weights of the last num_predictors samples (oldest first) for the (s+1)-th
  extrapolated sample."""
  # Build the "cumulative-sum constants" table for num_predictors points.
  for k in range(num_predictors):
    if k == 0:
      cumsumconsts = np.array([[1]])
    else:
      b = np.array(binom(k, np.arange(k+1)), dtype=int)
      b = b*np.expand_dims(b, 1)      # outer product of binomial coefficients
      b[1:,1:] += cumsumconsts        # accumulate the previous table
      cumsumconsts = b
  # The code below could be included in the above for-loop to compute coefs
  # for many num_predictors in one go.
  c = np.tile(cumsumconsts[0], (max(max_num_padding - k, 0), 1))
  for i in range(1, k + 1):
    c = np.vstack([cumsumconsts[i], c])
    if c.shape[0] > max_num_padding:
      c = c[:max_num_padding,:]
    c = np.cumsum(c, axis=0)          # repeated cumulative sums build up the rows
  c = c*(-1)**np.flip(np.arange(k+1)) # apply alternating signs
  return c
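
As a quick check, this reproduces the coefficients hard-coded at the top of the thread, for example:

print(make_lagrange_padding_coefs(1, 3))  # [[1, -3, 3]]
print(make_lagrange_padding_coefs(2, 5))  # [[1, -5, 10, -10, 5],
                                          #  [5, -24, 45, -40, 15]]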
