From 3af0bf4f4091693e67ec34de7a2113ef8da117b3 Mon Sep 17 00:00:00 2001
From: Katsuya Iida
Date: Mon, 9 Oct 2023 13:23:03 +0900
Subject: [PATCH] Add SpeechRecognizer (#19)

---
 NeMoOnnxSharp.Example/Program.cs            |  81 ++++++++
 NeMoOnnxSharp/SpeechRecognitionEventArgs.cs |  24 +++
 NeMoOnnxSharp/SpeechRecognizer.cs           | 198 ++++++++++++++++++++
 test_data/.gitignore                        |   2 +
 4 files changed, 305 insertions(+)
 create mode 100644 NeMoOnnxSharp/SpeechRecognitionEventArgs.cs
 create mode 100644 NeMoOnnxSharp/SpeechRecognizer.cs

diff --git a/NeMoOnnxSharp.Example/Program.cs b/NeMoOnnxSharp.Example/Program.cs
index 006a059..8194f46 100644
--- a/NeMoOnnxSharp.Example/Program.cs
+++ b/NeMoOnnxSharp.Example/Program.cs
@@ -6,6 +6,7 @@
 using System.Text;
 using System.Threading.Tasks;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 
 namespace NeMoOnnxSharp.Example
 {
@@ -31,6 +32,10 @@ static async Task Main(string[] args)
             {
                 await FramePredict(true);
             }
+            else if (task == "streamaudio")
+            {
+                await StreamAudio();
+            }
             else
             {
                 throw new InvalidDataException(task);
             }
@@ -43,6 +48,7 @@ static async Task Transcribe()
             string modelPath = await DownloadModelAsync("stt_en_quartznet15x5");
             string inputDirPath = Path.Combine(appDirPath, "..", "..", "..", "..", "test_data");
             string inputPath = Path.Combine(inputDirPath, "transcript.txt");
+
             using var model = new EncDecCTCModel(modelPath);
             using var reader = File.OpenText(inputPath);
             string? line;
@@ -110,6 +116,81 @@ static async Task FramePredict(bool mbn)
             }
         }
 
+        static async Task StreamAudio()
+        {
+            string appDirPath = AppDomain.CurrentDomain.BaseDirectory;
+            string vadModelPath = await DownloadModelAsync("vad_marblenet");
+            string asrModelPath = await DownloadModelAsync("stt_en_quartznet15x5");
+            string inputDirPath = Path.Combine(appDirPath, "..", "..", "..", "..", "test_data");
+            string inputPath = Path.Combine(inputDirPath, "transcript.txt");
+
+            using var recognizer = new SpeechRecognizer(vadModelPath, asrModelPath);
+            using var ostream = new FileStream(Path.Combine(inputDirPath, "result.txt"), FileMode.Create);
+            using var writer = new StreamWriter(ostream);
+            int index = 0;
+            recognizer.SpeechStartDetected += (s, e) =>
+            {
+                double t = (double)e.Offset / recognizer.SampleRate;
+                Console.WriteLine("SpeechStartDetected {0}", t);
+            };
+            recognizer.SpeechEndDetected += (s, e) =>
+            {
+                double t = (double)e.Offset / recognizer.SampleRate;
+                Console.WriteLine("SpeechEndDetected {0}", t);
+            };
+            recognizer.Recognized += (s, e) =>
+            {
+                double t = (double)e.Offset / recognizer.SampleRate;
+                Console.WriteLine("Recognized {0} {1} {2}", t, e.Audio?.Length, e.Text);
+                string fileName = string.Format("recognized-{0}.wav", index);
+                writer.WriteLine("{0}|{1}|{2}", fileName, e.Audio?.Length, e.Text);
+                if (e.Audio != null)
+                {
+                    WaveFile.WriteWAV(Path.Combine(inputDirPath, fileName), e.Audio, recognizer.SampleRate);
+                }
+                index += 1;
+            };
+            var stream = GetAllAudioStream(inputDirPath);
+            var buffer = new byte[1024];
+            while (true)
+            {
+                int bytesRead = stream.Read(buffer);
+                if (bytesRead == 0)
+                {
+                    break;
+                }
+                recognizer.Write(buffer.AsSpan(0, bytesRead));
+            }
+        }
+
+        private static MemoryStream GetAllAudioStream(
+            string inputDirPath,
+            int sampleRate = 16000,
+            double gapSeconds = 1.0)
+        {
+            string inputPath = Path.Combine(inputDirPath, "transcript.txt");
+            using var reader = File.OpenText(inputPath);
+            string? line;
+            var stream = new MemoryStream();
+            var waveform = new short[(int)(sampleRate * gapSeconds)];
+            var bytes = MemoryMarshal.Cast<short, byte>(waveform);
+            stream.Write(bytes);
+            while ((line = reader.ReadLine()) != null)
+            {
+                string[] parts = line.Split("|");
+                string name = parts[0];
+                string waveFile = Path.Combine(inputDirPath, name);
+                waveform = WaveFile.ReadWAV(waveFile, sampleRate);
+                bytes = MemoryMarshal.Cast<short, byte>(waveform);
+                stream.Write(bytes);
+                waveform = new short[(int)(sampleRate * gapSeconds)];
+                bytes = MemoryMarshal.Cast<short, byte>(waveform);
+                stream.Write(bytes);
+            }
+            stream.Seek(0, SeekOrigin.Begin);
+            return stream;
+        }
+
         private static async Task<string> DownloadModelAsync(string model)
         {
             using var downloader = new ModelDownloader();
diff --git a/NeMoOnnxSharp/SpeechRecognitionEventArgs.cs b/NeMoOnnxSharp/SpeechRecognitionEventArgs.cs
new file mode 100644
index 0000000..8f5e3ad
--- /dev/null
+++ b/NeMoOnnxSharp/SpeechRecognitionEventArgs.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Katsuya Iida. All Rights Reserved.
+// See LICENSE in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+namespace NeMoOnnxSharp
+{
+    public class SpeechRecognitionEventArgs
+    {
+        public SpeechRecognitionEventArgs(ulong offset, string? text = null, short[]? audio = null)
+        {
+            Offset = offset;
+            Text = text;
+            Audio = audio;
+        }
+
+        public ulong Offset { get; private set; }
+        public string? Text { get; private set; }
+        public short[]? Audio { get; private set; }
+    }
+}
diff --git a/NeMoOnnxSharp/SpeechRecognizer.cs b/NeMoOnnxSharp/SpeechRecognizer.cs
new file mode 100644
index 0000000..96193b6
--- /dev/null
+++ b/NeMoOnnxSharp/SpeechRecognizer.cs
@@ -0,0 +1,198 @@
+// Copyright (c) Katsuya Iida. All Rights Reserved.
+// See LICENSE in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.InteropServices;
+
+namespace NeMoOnnxSharp
+{
+    public class SpeechRecognizer : IDisposable
+    {
+        public delegate void SpeechStart(long position);
+        public delegate void SpeechEnd(long position, short[] audioSignal, string? transcript);
+
+        private readonly FrameVAD _frameVad;
+        private readonly EncDecCTCModel _asrModel;
+        private readonly int _audioBufferIncrease;
+        private readonly int _audioBufferSize;
+        int _audioBufferIndex;
+        long _currentPosition;
+        byte[] _audioBuffer;
+        bool _isSpeech;
+        private readonly float _speechStartThreshold;
+        private readonly float _speechEndThreshold;
+
+        private SpeechRecognizer(FrameVAD frameVad, EncDecCTCModel asrModel)
+        {
+            _frameVad = frameVad;
+            _asrModel = asrModel;
+            _currentPosition = 0;
+            _audioBufferIndex = 0;
+            _audioBufferSize = sizeof(short) * _frameVad.SampleRate * 2; // 2 sec of 16-bit audio
+            _audioBufferIncrease = sizeof(short) * 5 * _frameVad.SampleRate; // 5 sec of 16-bit audio
+            _audioBuffer = new byte[_audioBufferSize];
+            _isSpeech = false;
+            _speechStartThreshold = 0.7f;
+            _speechEndThreshold = 0.3f;
+        }
+
+        public SpeechRecognizer(string vadModelPath, string asrModelPath) : this(
+            new FrameVAD(vadModelPath), new EncDecCTCModel(asrModelPath))
+        {
+        }
+
+        public SpeechRecognizer(byte[] vadModel, byte[] asrModel) : this(
+            new FrameVAD(vadModel), new EncDecCTCModel(asrModel))
+        {
+        }
+
+        public int SampleRate => _frameVad.SampleRate;
+        public event EventHandler<SpeechRecognitionEventArgs>? Recognized;
+        public event EventHandler<SpeechRecognitionEventArgs>? SpeechStartDetected;
+        public event EventHandler<SpeechRecognitionEventArgs>? SpeechEndDetected;
+
+        public void Dispose()
+        {
+            _frameVad.Dispose();
+        }
+
+        public void Write(byte[] input, int offset, int count)
+        {
+            Write(input.AsSpan(offset, count));
+        }
+
+        public void Write(Span<byte> input)
+        {
+            while (input.Length > 0)
+            {
+                int len = input.Length;
+                if (_isSpeech)
+                {
+                    if (len > _audioBuffer.Length - _audioBufferIndex)
+                    {
+                        var tmp = new byte[_audioBuffer.Length + _audioBufferIncrease];
+                        Array.Copy(_audioBuffer, tmp, _audioBufferIndex);
+                        _audioBuffer = tmp;
+                    }
+                }
+                else
+                {
+                    if (_audioBufferIndex >= _audioBuffer.Length)
+                    {
+                        _audioBufferIndex = 0;
+                    }
+                    len = Math.Min(_audioBuffer.Length - _audioBufferIndex, len);
+                }
+                input.Slice(0, len).CopyTo(_audioBuffer.AsSpan(_audioBufferIndex, len));
+                input = input.Slice(len);
+                int len2 = (len / sizeof(short)) * sizeof(short);
+                var audioSignal = MemoryMarshal.Cast<byte, short>(_audioBuffer.AsSpan(_audioBufferIndex, len2));
+                _audioBufferIndex += len;
+                _currentPosition += audioSignal.Length;
+                _Transcribe(audioSignal);
+            }
+        }
+
+        private void _Transcribe(Span<short> audioSignal)
+        {
+            var pos = -(audioSignal.Length + _frameVad.PredictionOffset);
+            var result = _frameVad.Transcribe(audioSignal);
+            foreach (var prob in result)
+            {
+                if (_isSpeech)
+                {
+                    if (prob < _speechEndThreshold)
+                    {
+                        _isSpeech = false;
+                        int posBytes = pos * sizeof(short);
+                        if (Recognized != null)
+                        {
+                            var audio = _audioBuffer.AsSpan(0, _audioBufferIndex + posBytes);
+                            var x = MemoryMarshal.Cast<byte, short>(audio).ToArray();
+                            string predictText = _asrModel.Transcribe(x);
+                            Recognized(this, new SpeechRecognitionEventArgs(
+                                (ulong)(_currentPosition + pos), predictText, x));
+                        }
+                        if (SpeechEndDetected != null)
+                        {
+                            SpeechEndDetected(this, new SpeechRecognitionEventArgs(
+                                (ulong)(_currentPosition + pos)));
+                        }
+                        _ResetAudioBuffer(posBytes);
+                    }
+                }
+                else
+                {
+                    if (prob >= _speechStartThreshold)
+                    {
+                        _isSpeech = true;
+                        if (SpeechStartDetected != null) {
+                            SpeechStartDetected(this, new SpeechRecognitionEventArgs(
+                                (ulong)(_currentPosition + pos)));
+                        }
+                        int pos2 = pos * sizeof(short);
+                        _ChangeAudioBufferForSpeech(pos2);
+                    }
+                }
+                pos += _frameVad.HopLength;
+            }
+        }
+
+        private void _ResetAudioBuffer(int posBytes)
+        {
+            var tmp = new byte[_audioBufferSize];
+            Array.Copy(
+                _audioBuffer, _audioBufferIndex + posBytes,
+                tmp, 0,
+                -posBytes);
+            _audioBuffer = tmp;
+            _audioBufferIndex = -posBytes;
+        }
+
+        private void _ChangeAudioBufferForSpeech(int posBytes)
+        {
+            int audioBufferStart = _audioBufferIndex + posBytes;
+            int audioBufferEnd = _audioBufferIndex;
+            if (audioBufferStart >= 0)
+            {
+                Array.Copy(
+                    _audioBuffer, audioBufferStart,
+                    _audioBuffer, 0,
+                    audioBufferEnd - audioBufferStart);
+                _audioBufferIndex = audioBufferEnd - audioBufferStart;
+            }
+            else if (audioBufferStart + _audioBuffer.Length >= audioBufferEnd)
+            {
+                var tmp = new byte[_audioBuffer.Length + _audioBufferIncrease];
+                Array.Copy(
+                    _audioBuffer, audioBufferStart + _audioBuffer.Length,
+                    tmp, 0,
+                    -audioBufferStart);
+                Array.Copy(
+                    _audioBuffer, 0,
+                    tmp, -audioBufferStart,
+                    audioBufferEnd);
+                _audioBuffer = tmp;
+                _audioBufferIndex = audioBufferEnd - audioBufferStart;
+            }
+            else
+            {
+                var tmp = new byte[_audioBuffer.Length + _audioBufferIncrease];
+                Array.Copy(
+                    _audioBuffer, audioBufferEnd,
+                    tmp, 0,
+                    _audioBuffer.Length - audioBufferEnd);
+                Array.Copy(
+                    _audioBuffer, 0,
+                    tmp, _audioBuffer.Length - audioBufferEnd,
+                    audioBufferEnd);
+                _audioBufferIndex = _audioBuffer.Length;
+                _audioBuffer = tmp;
+            }
+        }
+    }
+}
diff --git a/test_data/.gitignore b/test_data/.gitignore
index 12ad52c..16b6a36 100644
--- a/test_data/.gitignore
+++ b/test_data/.gitignore
@@ -1 +1,3 @@
 /generated-*.wav
+/recognized-*.wav
+/result.txt
\ No newline at end of file
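
Usage note (not part of the patch): the example Program.cs feeds SpeechRecognizer from pre-recorded WAV files, but the same Write/event API can be driven by a live capture source. The sketch below is illustrative only; it assumes the NAudio package (WaveInEvent) for microphone input and placeholder model file paths, where the real paths would come from ModelDownloader as in Program.cs.

// Minimal sketch, assuming NAudio and locally available ONNX model files.
using System;
using NAudio.Wave;
using NeMoOnnxSharp;

class MicrophoneDemo
{
    static void Main()
    {
        // Model file names are placeholders for files obtained via ModelDownloader.
        using var recognizer = new SpeechRecognizer("vad_marblenet.onnx", "stt_en_quartznet15x5.onnx");
        recognizer.SpeechStartDetected += (s, e) =>
            Console.WriteLine("Speech start at {0:F2}s", (double)e.Offset / recognizer.SampleRate);
        recognizer.SpeechEndDetected += (s, e) =>
            Console.WriteLine("Speech end at {0:F2}s", (double)e.Offset / recognizer.SampleRate);
        recognizer.Recognized += (s, e) =>
            Console.WriteLine("Recognized: {0}", e.Text);

        // 16 kHz, 16-bit, mono PCM matches what SpeechRecognizer.Write expects.
        using var waveIn = new WaveInEvent
        {
            WaveFormat = new WaveFormat(recognizer.SampleRate, 16, 1),
            BufferMilliseconds = 100,
        };
        // DataAvailable delivers raw little-endian PCM bytes; pass them straight through.
        waveIn.DataAvailable += (s, e) => recognizer.Write(e.Buffer, 0, e.BytesRecorded);
        waveIn.StartRecording();

        Console.WriteLine("Listening... press Enter to stop.");
        Console.ReadLine();
        waveIn.StopRecording();
    }
}

One caveat that follows from the patch itself: Write runs VAD and, at the end of a speech segment, full ASR inference synchronously on the calling thread, so a real-time capture callback may need to hand chunks off to a worker queue if the ASR model is slow.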