How to export models to ONNX #39

Open
lanyuer opened this issue Jul 24, 2024 · 3 comments

Comments

@lanyuer

lanyuer commented Jul 24, 2024

Thank you for your work; the performance of this model is quite good. I would like to deploy and use it. Is there a way to export it to ONNX?

@skeskinen

I've made an ONNX version of the denoiser here: https://github.com/skeskinen/resemble-denoise-onnx-inference
It's deployed at https://smartmediacutter.com/ and works quite nicely.

It does not include the enhancer, and I'm not sure how easy it would be to export the enhancer model to ONNX. The enhancer is too slow for my use case at the moment.

@lanyuer
Author

lanyuer commented Jul 26, 2024

Thank you so much for sharing it! It's really helpful and I appreciate your work.

@alexey-laletin-singularitynet

alexey-laletin-singularitynet commented Oct 14, 2024

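My workaround led me to the following fixes, which make the enhancer exportable to ONNX.

Comment out the torch.Generator, because it cannot be exported to ONNX. Sampling in eval mode then draws from the global RNG instead of a fixed seed.

resemble_enhance/enhancer/lcfm/cfm.py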

def _sample_ψ0(self, x: Tensor):
    """
    Args:
        x: (b c t), which implies the shape of ψ0
    """
    shape = list(x.shape)
    shape[1] = self.output_dim
    # if self.training:
    #     g = None
    # else:
    #     g = torch.Generator(device=x.device)
    #     g.manual_seed(0)  # deterministic sampling during eval
    ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype) # , generator=g)
    return ψ0

Change the .expand call to .repeat.

resemble_enhance/enhancer/univnet/alias_free_torch/resample.py

def forward(self, x):
    _, C, _ = x.shape
    x = F.pad(x, (self.pad, self.pad), mode='replicate')
    weight = self.filter.repeat([int(x.shape[1]), 1, 1])
    x = self.ratio * F.conv_transpose1d(x, weight, stride=self.stride, groups=C) # self.filter.expand(C, -1, -1)
    shape = x.shape
    shape = [int(elem) for elem in shape]
    x = torch.reshape(x, shape)
    x = x[..., self.pad_left:-self.pad_right]
 
    return x

Write your own custom Tensor.unfold function

resemble_enhance/enhancer/univnet/lvcnet.py

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernel.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
 
        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
 
        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = custom_unfold_dim_2(x, hop_size + 2 * padding, hop_size)
        # x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
 
        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = custom_unfold_dim_3(x, dilation, dilation)
        # x = x.unfold(3, dilation, dilation)  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = custom_unfold_dim_4(x, kernel_size, 1)
        # x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
 
        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)
 
        return o
 
def custom_unfold_dim_2(x: torch.Tensor, window_size: int, step: int):
    dim = 2
    subtensors = [x[:, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
    result = torch.stack(subtensors, dim=dim)
    return result
 
def custom_unfold_dim_3(x: torch.Tensor, window_size: int, step: int):
    dim = 3
    subtensors = [x[:, :, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
    result = torch.stack(subtensors, dim=dim)
    return result
 
def custom_unfold_dim_4(x: torch.Tensor, window_size: int, step: int):
    dim = 4
    subtensors = [x[:, :, :, :, i:i + window_size, ...] for i in range(0, x.size(dim) - window_size + 1, step)]
    result = torch.stack(subtensors, dim=dim)
    return result
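The three custom_unfold_* helpers above all use the same windowing rule: a window of window_size items starts at every index in range(0, length - window_size + 1, step), and the windows are stacked along a new axis. A minimal pure-Python sketch of that rule on a flat list is shown below. The names sliding_windows and window_count are hypothetical and only for illustration; they are not part of resemble-enhance or PyTorch.

```python
def sliding_windows(seq, size, step):
    """Return windows of length `size`, one starting every `step` items.

    Uses the same start indices as the custom_unfold_* helpers:
    range(0, len(seq) - size + 1, step). Trailing items that cannot
    fill a whole window are dropped.
    """
    return [seq[i:i + size] for i in range(0, len(seq) - size + 1, step)]


def window_count(length, size, step):
    """Number of full windows: floor((length - size) / step) + 1, or 0 if too short."""
    if length < size:
        return 0
    return (length - size) // step + 1


samples = list(range(10))
windows = sliding_windows(samples, size=4, step=3)
print(windows)                            # [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
print(window_count(len(samples), 4, 3))   # 3
```

Because the helpers slice with a trailing "..." and stack along the sliced axis, they only reproduce Tensor.unfold when the unfolded dimension is the last one, which is the case at all three call sites. Comparing each helper against x.unfold(dim, size, step) with torch.equal on a small random tensor is a quick way to confirm this before exporting.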

Export script

P.S. I had to rename the resemble-enhance source folder to src to avoid conflicts with the resemble-enhance package that was also installed in my venv.

import logging
import time
import os

import click
import torch
import torchaudio
from torch.nn.functional import pad
from torchaudio.functional import resample

from src.enhancer.inference import load_enhancer
from src.hparams import HParams
from src.inference import merge_chunks, remove_weight_norm_recursively

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@click.command()
@click.option('--wav-path', type=click.Path(exists=True), help='Path to input wav file')
@click.option('--save-path', type=str, default='output.wav', help='Path to save output wav file')
@click.option('--run-dir', type=str, default=None, help='Path to run directory')
@click.option('--device', type=str, default='cuda', help='Device to use for computation')
@click.option('--nfe', type=int, default=32, help='Number of function evaluations')
@click.option('--solver', type=str, default='midpoint', help='Numerical solver to use')
@click.option('--lambd', type=float, default=0.5, help='Denoise strength')
@click.option('--tau', type=float, default=0.5, help='CFM prior temperature')
@click.option('--chunk-seconds', type=float, default=30.0, help='Length of each chunk in seconds')
@click.option('--overlap-seconds', type=float, default=1.0, help='Overlap between chunks in seconds')
@click.option('--export-onnx', type=bool, default=False, help='Whether to export the enhancer model to ONNX')
@click.option('--onnx-path', type=click.Path(exists=True, file_okay=False), default="onnx", help='Existing directory where the ONNX files will be saved')
def main(
    wav_path: str,
    save_path: str = "output.wav",
    run_dir: str | None = None,
    device: str = "cuda",
    nfe: int = 32,
    solver: str = "midpoint",
    lambd: float = 0.5,
    tau: float = 0.5,
    chunk_seconds: float = 30.0, 
    overlap_seconds: float = 1.0,
    export_onnx: bool = False,
    onnx_path: str = "onnx",
):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        if device == "cuda":
            torch.cuda.empty_cache()
    else:
        device = "cpu"
    
    enhancer = load_enhancer(run_dir, device)
    enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    
    enhancer.eval()
    enhancer.lcfm.eval()
    remove_weight_norm_recursively(enhancer)
    hp: HParams = enhancer.hp
    enhancer.to(device)
    
    dwav, sr = torchaudio.load(wav_path)
    dwav = dwav.mean(dim=0)
    
    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=hp.wav_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )
    
    result_audio_length = dwav.shape[-1]

    start_time = time.perf_counter()

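    # dwav was resampled to hp.wav_rate above, so chunk sizes must use that rate, not the file's original sr
    chunk_length = int(hp.wav_rate * chunk_seconds)
    overlap_length = int(hp.wav_rate * overlap_seconds)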
    hop_length = chunk_length - overlap_length
    
    chunks = [dwav[i:i + chunk_length] for i in range(0, dwav.shape[-1], hop_length)]
    input_chunks = torch.stack([pad(chunk, (0, chunk_length - len(chunk))) for chunk in chunks], dim=0)
    
    abs_max = input_chunks.abs().max(dim=1, keepdim=True).values
    abs_max[abs_max == 0] = 10e-7
    input_chunks = input_chunks / abs_max
    input_chunks = input_chunks.to(device)
    
    # "inference_mode() and no_grad()" evaluates to the no_grad() object alone, so enter it directly
    with torch.no_grad():
        output = enhancer(input_chunks).to("cpu")
        output = output * abs_max
    
    audio = merge_chunks(output, chunk_length, hop_length, sr=hp.wav_rate)
    
    elapsed_time = time.perf_counter() - start_time
    logger.info(f"Elapsed time: {elapsed_time:.3f} s, {audio.shape[-1] / elapsed_time / 1000:.3f} kHz")
    
    torchaudio.save(save_path, audio[None, :result_audio_length], hp.wav_rate)
    
    if export_onnx:
        logger.info("Exporting enhancer model to ONNX")
        with torch.no_grad():
            mel_spectrogram = enhancer.to_mel(input_chunks)
            normalizer_result = enhancer.normalizer(mel_spectrogram)
            lcfm_result = enhancer.lcfm(normalizer_result)
            
            logger.info("Exporting normalizer model to ONNX")
            torch.onnx.export(
                enhancer.normalizer,
                mel_spectrogram,
                os.path.join(onnx_path, "normalizer.onnx"),
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names = ['input'],
                output_names = ['output'],
                dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
                            'output' : {0 : 'batch_size', 2 : 'mel_length'}}
            )
            
            logger.info("Exporting lcfm model to ONNX")
            torch.onnx.export(
                enhancer.lcfm,
                normalizer_result,
                os.path.join(onnx_path, "lcfm.onnx"),
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names = ['input'],
                output_names = ['output'],
                dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
                            'output' : {0 : 'batch_size', 2 : 'mel_length'}}
            )
            
            logger.info("Exporting vocoder model to ONNX")
            torch.onnx.export(
                enhancer.vocoder,
                lcfm_result,
                os.path.join(onnx_path, "vocoder.onnx"),
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names = ['input'],
                output_names = ['output'],
                dynamic_axes={'input' : {0 : 'batch_size', 2 : 'mel_length'},
                            'output' : {0 : 'batch_size', 1 : 'audio_length'}}
            )
    

if __name__ == "__main__":
    main()
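The chunking arithmetic in the script can be checked in isolation. The sketch below computes the (start, end) sample ranges that the list comprehension over dwav produces, given a sample rate, chunk length and overlap. The function name chunk_bounds and the toy rate of 10 Hz are assumptions made for illustration; the sample_rate passed in should be the rate of the audio actually being sliced.

```python
def chunk_bounds(num_samples, sample_rate, chunk_seconds=30.0, overlap_seconds=1.0):
    """Return (start, end) sample indices for overlapping chunks.

    Mirrors the script's arithmetic: chunks are chunk_length samples long and
    start every hop_length = chunk_length - overlap_length samples. The final
    chunk may be shorter; the script zero-pads it to chunk_length.
    """
    chunk_length = int(sample_rate * chunk_seconds)
    overlap_length = int(sample_rate * overlap_seconds)
    hop_length = chunk_length - overlap_length
    if hop_length <= 0:
        raise ValueError("overlap must be shorter than the chunk")
    return [
        (start, min(start + chunk_length, num_samples))
        for start in range(0, num_samples, hop_length)
    ]


# 70 seconds of audio at a toy rate of 10 Hz: three chunks, each overlapping
# the previous one by 10 samples (1 second).
print(chunk_bounds(num_samples=700, sample_rate=10))
# [(0, 300), (290, 590), (580, 700)]
```

If the rate used here differs from the rate of the tensor being sliced, the chunks are not chunk_seconds long: for a file at a lower rate than the model rate they come out shorter, and more chunks are fed to the model than intended.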
