Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/kaiidams/NeMoOnnxSharp into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
kaiidams committed Aug 13, 2023
2 parents 6e63180 + 1bf353b commit f285765
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 325 deletions.
51 changes: 10 additions & 41 deletions Python/export_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import nemo
import importlib
from omegaconf import OmegaConf


def get_class(cls_path):
Expand All @@ -12,46 +12,15 @@ def get_class(cls_path):
# Load a pretrained model, export it to ONNX, and print its configuration.
#
# cls_path:   identifies the model class; it is resolved with get_class().
# model_name: passed to cls.from_pretrained(); also used as the stem of the
#             output file, which is written as '<model_name>.onnx'.
#
# Steps: resolve the class, load the pretrained model, print the model,
# call model.export() with the .onnx file name, then print model._cfg
# rendered as YAML by OmegaConf.to_yaml().
#
# NOTE(review): this text comes from a diff view without +/- markers, so the
# statements below may mix lines from the old and the new revision - confirm
# against the actual file before relying on the exact sequence.
def export(cls_path: str, model_name: str):
cls = get_class(cls_path)
model = cls.from_pretrained(model_name)
print(model)
model.export(f'{model_name}.onnx')
print(OmegaConf.to_yaml(model._cfg))



#export("nemo_asr.models.EncDecClassificationModel")

#import nemo
#import nemo.collections.asr as nemo_asr
#import nemo.collections.asr


# Script section: run the 'vad_marblenet' model's preprocessor over every
# ../test_data/*.wav file and store the result next to it as a .bin file.
cls_path = "nemo.collections.asr.models.EncDecClassificationModel"
#export(cls_path, 'vad_marblenet')
cls = get_class(cls_path)
model = cls.from_pretrained('vad_marblenet')
# Show both the instantiated preprocessor module and its configuration entry.
print(model.preprocessor)
print(model._cfg.preprocessor)
import librosa
import torch
import numpy as np
import struct
from glob import glob
for wave_file in glob("../test_data/*.wav"):
# Load the audio with sr=16000.
audio, sample_rate = librosa.load(wave_file, sr=16000)
# NOTE(review): librosa.load is documented to return floating-point samples;
# the extra / 32768.0 looks like a scaling meant for int16 data - confirm
# that it is intended here.
audio_signal = torch.from_numpy(audio / 32768.0).to(torch.float32)
# Add a leading dimension of size 1 and pass the sample count as the length.
audio_signal = torch.unsqueeze(audio_signal, 0)
audio_signal_len = torch.tensor([audio.shape[0]], dtype=torch.int64)
processed_signal, processed_signal_len = model.preprocessor(
input_signal=audio_signal, length=audio_signal_len,
)
print(processed_signal.shape, processed_signal_len)
# Keep only the first entry along the leading dimension.
processed_signal = processed_signal[0]
bin_file = wave_file.replace('.wav', '.bin')
# File layout: two ints holding the shape ('2i', so the tensor must be 2-D
# at this point), followed by the raw bytes of the tensor data.
with open(bin_file, "wb") as fp:
fp.write(struct.pack('2i', *list(processed_signal.shape)))
fp.write(processed_signal.numpy().tobytes())

# audio_signal, audio_signal_len = batch
# audio_signal, audio_signal_len = audio_signal.to(vad_model.device), audio_signal_len.to(vad_model.device)
# processed_signal, processed_signal_len = vad_model.preprocessor(
# input_signal=audio_signal, length=audio_signal_len,
# )
# Script section: choose the model class and the pretrained model to export.
# Each assignment overwrites the previous one, so only the LAST cls_path and
# the LAST model_name reach export(); the earlier lines are effectively a
# list of alternatives that can be selected by reordering them.
cls_path = 'nemo.collections.asr.models.EncDecClassificationModel'
cls_path = 'nemo.collections.asr.models.EncDecCTCModel'
cls_path = 'nemo.collections.asr.models.EncDecClassificationModel'
model_name = 'vad_marblenet'
model_name = 'stt_en_quartznet15x5'
model_name = 'stt_en_jasper10x5dr'
model_name = 'commandrecognition_en_matchboxnet3x1x64_v2'
# Runs at module level: exports the selected model to '<model_name>.onnx'.
export(cls_path, model_name)
9 changes: 6 additions & 3 deletions Python/make_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Katsuya Iida. All Rights Reserved.
# See LICENSE in the project root for license information.

import librosa
import torch
from nemo.collections.asr.modules import (
Expand All @@ -7,7 +10,7 @@


def main():
wavpath = "61-70968-0000.wav"
wavpath = "../NemoOnnxSharp.Tests/Data/61-70968-0000.wav"
sr = 16000
audio_signal, sr = librosa.load(wavpath, sr=sr, mono=True)
assert audio_signal.ndim == 1
Expand Down Expand Up @@ -35,7 +38,7 @@ def convert_mel_spectrogram(audio_signal, length):
processed_signal, processed_signal_length = preprocessor(input_signal=audio_signal, length=length)
print(processed_signal, processed_signal_length)
print(processed_signal.shape, processed_signal_length)
with open('mel_spectrogram.bin', 'wb') as fp:
with open("../NemoOnnxSharp.Tests/Data/mel_spectrogram.bin", 'wb') as fp:
fp.write(processed_signal[0].T.numpy().tobytes("C"))


Expand All @@ -52,7 +55,7 @@ def convert_mfcc(audio_signal, length):
processed_signal, processed_signal_length = preprocessor(input_signal=audio_signal, length=length)
print(processed_signal, processed_signal_length)
print(processed_signal.shape, processed_signal_length)
with open('mfcc.bin', 'wb') as fp:
with open("../NemoOnnxSharp.Tests/Data/mfcc.bin", 'wb') as fp:
fp.write(processed_signal[0].T.numpy().tobytes("C"))


Expand Down
281 changes: 0 additions & 281 deletions Python/quartznet_15x5.yaml

This file was deleted.

0 comments on commit f285765

Please sign in to comment.