
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file .../csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu #19

Open
dwromero opened this issue Feb 7, 2024 · 8 comments


dwromero commented Feb 7, 2024

Hi Dan & Hermann,

I am trying to run some experiments with FlashFFTConv, but I am afraid I am encountering the following error:

ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument

For debugging, I am running the following:

y = fftconv_fn(x.to(dtype=fftconv_fn.dtype).contiguous(), k.float()).to(dtype=x.dtype)

where fftconv_fn is a FlashFFTConv instance with use_32_butterfly=True. Both torch.float16 and torch.bfloat16 lead to the same error.

Any help on how to solve this issue would be much appreciated!

David

Kumbong (Collaborator) commented Feb 7, 2024

Hey David,

Thank you so much for your interest in using FlashFFTConv. I am wondering, what GPU card are you running this on (A100, H100, etc.)? We only tested on A100 and H100, and I suspect the issue comes from using a card that does not have as much total shared memory as the A100. We can see how to extend this if needed. Also, what size of FFTConv are you currently using, i.e. 32K? 16K?
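
For reference, the failing call asks the CUDA runtime to opt the kernel into 135168 bytes (132 KB) of dynamic shared memory per block, and cudaFuncSetAttribute returns invalid argument when that exceeds the card's opt-in limit. A rough way to check your card from Python (just a sketch on my side; the per-architecture numbers below are the published opt-in maxima, not values queried from the driver):

import torch

# Assumed table of published per-block opt-in limits for dynamic shared memory, in bytes:
# sm_80 = A100, sm_86 = A40/A6000/3090, sm_89 = Ada (RTX 4090, RTX 6000 Ada), sm_90 = H100.
OPT_IN_SMEM = {(8, 0): 166912, (8, 6): 101376, (8, 9): 101376, (9, 0): 232448}

props = torch.cuda.get_device_properties(0)
cap = (props.major, props.minor)
limit = OPT_IN_SMEM.get(cap)
print(f"{props.name}: compute capability {cap}, opt-in shared memory limit {limit} bytes")
print("32K fp16/bf16 kernel requests 135168 bytes:",
      "should be OK" if limit is not None and limit >= 135168 else "will fail")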

dwromero (Author) commented Feb 7, 2024

Hi Hermann,

"I am wondering, what GPU card are you running this on (A100, H100, etc.)?"
Oh, good point! I am using an RTX 6000 Ada. I will check whether I get the same error on an A100.

"Also, what size of FFTConv are you currently using, i.e. 32K? 16K?"
I tried different input lengths with FlashFFTConv objects of 2*seq_length to make them causal. I tried inputs of length:

  • 16384, 32768, 65536.

Funnily enough, the length that causes the problem is 16384, the shortest one! The other lengths do not raise that error.
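
Concretely, the pattern I use looks roughly like this (a minimal sketch rather than my actual training code; it assumes the FlashFFTConv object is built with an FFT size of 2 * L so that inputs and kernels of length L give a causal convolution):

import torch
from flashfftconv import FlashFFTConv

L = 16384  # raw input length; this is the case that fails for me
conv = FlashFFTConv(2 * L, dtype=torch.bfloat16, use_32_butterfly=True).cuda()

x = torch.randn(4, 128, L, dtype=torch.bfloat16, device="cuda")  # (B, H, L)
k = torch.randn(128, L, dtype=torch.float32, device="cuda")      # kernel stays fp32

y = conv(x.contiguous(), k)  # raises the cudaFuncSetAttribute error above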

David

dwromero (Author) commented Feb 8, 2024

Hi Hermann,

I tried it out on an A100-40GB, but unfortunately I keep getting errors related to the package :/

CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:774                                                                                                                                                                  
misaligned address
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: misaligned address
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/pytorch/pytorch/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7f1e2d99c8f9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f1e2d951bb6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3c2 (0x7f1e3898fe12 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0xe5c485 (0x7f1dcf8b5485 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xe59644 (0x7f1dcf8b2644 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x483b00 (0x7f1e2ca1cb00 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #6: c10::TensorImpl::~TensorImpl() + 0x9 (0x7f1e2d978419 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #7: <unknown function> + 0x74b788 (0x7f1e2cce4788 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: THPVariable_subclass_dealloc(_object*) + 0x296 (0x7f1e2cce4a96 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x136991 (0x560e2e34f991 in /usr/bin/python)
frame #10: <unknown function> + 0x13678c (0x560e2e34f78c in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x560e2e34ec52 in /usr/bin/python)
frame #12: <unknown function> + 0x25b035 (0x560e2e474035 in /usr/bin/python)
frame #13: _PyEval_EvalFrameDefault + 0xa33b (0x560e2e365eeb in /usr/bin/python)
frame #14: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #15: PyObject_Call + 0x122 (0x560e2e382492 in /usr/bin/python)
frame #16: _PyEval_EvalFrameDefault + 0x2a27 (0x560e2e35e5d7 in /usr/bin/python)
frame #17: <unknown function> + 0x1687f1 (0x560e2e3817f1 in /usr/bin/python)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x560e2e35d53c in /usr/bin/python)
frame #19: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x8ac (0x560e2e35c45c in /usr/bin/python)
frame #21: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x2a27 (0x560e2e35e5d7 in /usr/bin/python)
frame #23: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #25: <unknown function> + 0x13f9c6 (0x560e2e3589c6 in /usr/bin/python)
frame #26: PyEval_EvalCode + 0x86 (0x560e2e44e256 in /usr/bin/python)
frame #27: <unknown function> + 0x23ae2d (0x560e2e453e2d in /usr/bin/python)
frame #28: <unknown function> + 0x15ac59 (0x560e2e373c59 in /usr/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #30: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #32: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #33: <unknown function> + 0x252c2d (0x560e2e46bc2d in /usr/bin/python)
frame #34: Py_RunMain + 0x128 (0x560e2e46a8c8 in /usr/bin/python)
frame #35: Py_BytesMain + 0x2d (0x560e2e44102d in /usr/bin/python)
frame #36: <unknown function> + 0x29d90 (0x7f1e3b39ad90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #37: __libc_start_main + 0x80 (0x7f1e3b39ae40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #38: _start + 0x25 (0x560e2e440f25 in /usr/bin/python)

Kumbong (Collaborator) commented Feb 8, 2024

Hi David,

Do you mind sharing what version of PyTorch and CUDA you are using, so that I can try to reproduce your error on my end and see what the issue could be? I suspect this could be from the version of PyTorch or CUDA.

We tested on:

  • PyTorch 2.0
  • CUDA version 12.1 and toolkit version 12.1

dwromero (Author) commented Feb 9, 2024

I am testing on PyTorch '2.2.0a0+81ea7a4'. CUDA and Toolkit versions 12.3. Do you think this might be causing the problem?

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

By the way, I created a small benchmark to pinpoint the errors:

import torch
from flashfftconv import FlashFFTConv


# Instantiate FlashFFTConv versions of all possible lengths
def instantiate_fftconv_functions(
    max_length: int = 64 * 32768,  # Default max_length is about 2M elements
    start_length: int = 256,
    dtype: torch.dtype = torch.float16,
    use_32_butterfly: bool = True,
) -> dict[str, FlashFFTConv]:
    """
    Initializes flash FFT convolution functions across a range of sequence lengths.

    Sequence lengths start at start_length and are doubled until they exceed the specified maximum length.

    Parameters:
    - max_length (int): The maximum sequence length for which to instantiate FFT convolution functions.
                        Defaults to 2,097,152 (64 * 32768).
    - start_length (int): The smallest sequence length to instantiate. Defaults to 256.
    - dtype (torch.dtype): The data type to use for the FFT convolution operations. Defaults to torch.float16.
    - use_32_butterfly (bool): Passed through to FlashFFTConv; selects the 32-size butterfly decomposition.

    Returns:
    - dict[str, FlashFFTConv]: A dictionary mapping sequence lengths to their corresponding FFT convolution functions.
                               The keys are strings so that the dictionary can be used within torch.nn.ModuleDict.
    """
    fftconv_functions = {}
    while start_length <= max_length:
        fftconv_functions[str(start_length)] = FlashFFTConv(
            start_length, dtype=dtype, use_32_butterfly=use_32_butterfly
        )
        start_length = start_length * 2
    return fftconv_functions


class DummyModule(torch.nn.Module):
    def __init__(self,
                 fftfns,
                 n_hidden,
                 dtype):
        super().__init__()
        self.fftfns = torch.nn.ModuleDict(fftfns)
        weights = self._create_weights(n_hidden, dtype)
        self.weights = torch.nn.ParameterDict(weights)

    def forward(self, x, key):
        return self.fftfns[key](x, self.weights[key])

    def _create_weights(self, n_hidden, dtype) -> dict[str, torch.nn.Parameter]:
        weights = {}
        for key in self.fftfns.keys():
            weights[key] = torch.nn.Parameter(torch.randn([n_hidden, int(key) // 2], dtype=dtype))
        return weights


if __name__ == "__main__":
    DTYPE = torch.float16
    N_HIDDEN = 128
    BATCH_SIZE = 4

    fftfns = instantiate_fftconv_functions(
        max_length=64*32768,
        start_length=512,
        dtype=DTYPE,
        use_32_butterfly=True)
    # create dummy module.
    model = DummyModule(fftfns, n_hidden=N_HIDDEN, dtype=torch.float32)
    model.cuda()
    model.train()

    print('-' * 50)

    # Iterate (fwd and bwd) through sequences:
    for key in model.fftfns.keys():
        # Create an input tensor of appropriate shape
        # Example: tensor shape [batch_size, sequence_length]
        # Both the weights and the inputs have length key // 2 so that the convolution is causal.
        input_length = int(key)
        input_tensor = torch.randn(BATCH_SIZE, N_HIDDEN, input_length // 2, dtype=DTYPE, device="cuda")

        layer = model.fftfns[key]
        print(f"Conv layer: {layer}, seq_len = {layer.seqlen}, dtype = {layer.dtype}, use_32_butterfly = {layer.use_32_butterfly}")
        print(f"Input size: {input_tensor.shape}")

        # Run the model's forward pass
        output = model(input_tensor, key)

        # Compute a simple loss (e.g., mean squared error against a target of the same shape as output)
        target = torch.randn_like(output)  # Dummy target tensor of the same shape as the model's output
        loss = torch.nn.functional.mse_loss(output, target)
        print(f"Loss: {loss}")

        # Run the backward pass to compute gradients
        loss.backward()
        print(f"Gradient computed. Gradient on weights: {model.weights[key]._grad.sum()}")
        model.zero_grad()
        print(f"Gradient resetted.")

        print('-' * 50)


    print('Done')
  • On an A6000 Ada I get the following:

    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 256])
    Loss: 128.625
    Gradient computed. Gradient on weights: -1.1250529289245605
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 512])
    Loss: 258.0
    Gradient computed. Gradient on weights: 2.563692092895508
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 1024])
    Loss: 515.5
    Gradient computed. Gradient on weights: -3.3228659629821777
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 2048])
    Loss: 1026.0
    Gradient computed. Gradient on weights: 6.598326683044434
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 4096])
    Loss: inf
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 620 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:710
    invalid argument
    Gradient computed. Gradient on weights: nan
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 8192])
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 561 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:665
    invalid argument
    Loss: nan
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 820 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:882
    invalid argument
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 16384])
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 702 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:774
    invalid argument
    Loss: nan
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 987 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:1045
    invalid argument
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 32768])
    Loss: inf
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 65536])
    Loss: inf
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 131072])
    Loss: inf
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_complex_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 267 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:356
    invalid argument
    Gradient computed. Gradient on weights: -8.537776947021484
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 262144])
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 356 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu:436
    invalid argument
    Loss: 1.0
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 435 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:495
    invalid argument
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 524288])
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 474 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu:546
    invalid argument
    Loss: 1.0
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 569 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:625
    invalid argument
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
    Input size: torch.Size([4, 128, 1048576])
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 474 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu:546
    invalid argument
    Loss: 1.0
    ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 569 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
    CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:625
    invalid argument
    Gradient computed. Gradient on weights: 0.0
    Gradient resetted.
    --------------------------------------------------
    Done
    

    So, on the A6000 Ada it only works cleanly for input lengths up to 2048 (FFT size 4096).

  • On an A100-40GB I see the following:

--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.1250529289245605
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.387176513671875
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 517.5
Gradient computed. Gradient on weights: -2.1371450424194336
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 5.117058753967285
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Gradient computed. Gradient on weights: 5.128373146057129
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done

I am afraid the results are the same. It only works until input length 2048.

EDIT: I verified with nvcr.io/nvidia/pytorch:23.05-py3 and installing as mentioned in the repo with:

pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
pip install git+https://github.com/HazyResearch/flash-fft-conv.git

Unfortunately, I get the same results with both dtype=torch.bfloat16 and torch.float16. EDIT 2: Building from source produces the same result with both dtypes.

  • On an A6000 Ada:
 --------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 129.0
Gradient computed. Gradient on weights: -1.1330515146255493
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 2.6320414543151855
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 516.0
Gradient computed. Gradient on weights: -3.4659423828125
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1032.0
Gradient computed. Gradient on weights: 5.9667158126831055
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: 2064.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 819 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:909
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 828 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:932
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 1012 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:1106
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 1198 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:1256
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: 16512.0
Gradient computed. Gradient on weights: -3.5391926765441895
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: 33024.0
Gradient computed. Gradient on weights: -7.784242153167725
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: 66048.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_complex_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 271 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:360
invalid argument
Gradient computed. Gradient on weights: -8.243345260620117
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 357 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:437
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 439 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:530
invalid argument
Gradient computed. Gradient on weights: -103.35313415527344
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 475 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:547
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 603 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:659
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 475 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:547
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 603 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:659
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done

FINAL EDIT

I found out where the problem lies. After trying an A100-40GB, an A6000 Ada, and an A100-80GB, I noticed that all sequence lengths work only on the A100-80GB, and only with torch.bfloat16. With torch.float16, the A100-80GB does not raise an error for long sequences, but the loss becomes infinite.

  • A100-80GB - torch.float16
-------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.124933123588562
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.387176513671875
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 517.5
Gradient computed. Gradient on weights: -2.1371450424194336
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 5.117058753967285
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Gradient computed. Gradient on weights: 5.128373146057129
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done
  • A100-80GB - torch.bfloat16
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 129.0
Gradient computed. Gradient on weights: -1.1331242322921753
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.492912769317627
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 520.0
Gradient computed. Gradient on weights: -2.112046718597412
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1032.0
Gradient computed. Gradient on weights: 4.173070430755615
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: 2064.0
Gradient computed. Gradient on weights: 7.480684280395508
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: 4128.0
Gradient computed. Gradient on weights: -7.906074047088623
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: 8256.0
Gradient computed. Gradient on weights: 3.4166970252990723
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: 16512.0
Gradient computed. Gradient on weights: -5.117952346801758
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: 33024.0
Gradient computed. Gradient on weights: -7.764023780822754
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: 66048.0
Gradient computed. Gradient on weights: -65.47335052490234
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: 132096.0
Gradient computed. Gradient on weights: -72.84762573242188
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: 264192.0
Gradient computed. Gradient on weights: -73.87054443359375
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: 528384.0
Gradient computed. Gradient on weights: -95.62330627441406
Gradient resetted.
--------------------------------------------------
Done

I hope these insights help a bit in getting clarity on what's missing :) On my side, I'll continue using A100-80GBs in the meantime.

Thank you!

David

catid commented Feb 11, 2024

Seeing the same issue here on RTX 4090:

ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument

Running the example code in the README (with some obvious fixes):

# https://github.com/HazyResearch/flash-fft-conv
from flashfftconv import FlashFFTConv

import torch

# size of the FFT
my_flashfftconv = FlashFFTConv(32768, dtype=torch.bfloat16) # generally more stable!
my_flashfftconv.cuda()

# B is batch size, H is model dimension, L is sequence length
B = 16
H = 768
# input can be smaller than FFT size, but needs to be divisible by 2
L = 16384

# the input, B H L
x = torch.randn(B, H, L, dtype=torch.bfloat16).cuda() # same type as the input
k = torch.randn(H, L, dtype=torch.float32).cuda() # kernel needs to be fp32 for now

out = my_flashfftconv(x, k)

(sssm) ➜ spectral_ssm git:(main) ✗ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0
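
For anyone else blocked on this, a plain torch.fft version is handy as a reference while the fused kernel is unavailable on these cards (a sketch that assumes FlashFFTConv(N) computes the length-N FFT convolution truncated back to the input length; much slower, but it runs anywhere):

import torch

def ref_fftconv(x, k, fft_size):
    # x: (B, H, L), k: (H, L); zero-pad both to fft_size, multiply in frequency space, truncate to L.
    L = x.shape[-1]
    x_f = torch.fft.rfft(x.float(), n=fft_size)
    k_f = torch.fft.rfft(k.float(), n=fft_size)
    y = torch.fft.irfft(x_f * k_f, n=fft_size)[..., :L]
    return y.to(x.dtype)

out_ref = ref_fftconv(x, k, 32768)  # same x, k, and FFT size as the README snippet above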

catid commented Feb 11, 2024

Running the benchmark script shared above, I see the first error here:

Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 620 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:710
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.

DanFu09 (Contributor) commented Feb 12, 2024

Thanks for the detailed bug report! I believe the issues on non-A100 are related to #6. We’ll have to take a closer look at the others.

It may be a little while until we can get to it (I’m busy with faculty search, and @Kumbong is about to visit a bunch of PhD programs).
