Add SpeechRecognizer (#19)
kaiidams authored Oct 9, 2023
1 parent 7925b96 commit 3af0bf4
Showing 4 changed files with 305 additions and 0 deletions.
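This commit wires the MarbleNet frame VAD (vad_marblenet) and the QuartzNet CTC model (stt_en_quartznet15x5) into a streaming SpeechRecognizer that raises speech start/end and recognition events. As a quick orientation, here is a minimal usage sketch condensed from the StreamAudio example added below; vadModelPath, asrModelPath, and pcmBytes (raw 16-bit mono PCM at recognizer.SampleRate) are placeholders, not part of the commit:

using var recognizer = new SpeechRecognizer(vadModelPath, asrModelPath);
recognizer.SpeechStartDetected += (s, e) =>
    Console.WriteLine("speech start at {0:F2}s", (double)e.Offset / recognizer.SampleRate);
recognizer.SpeechEndDetected += (s, e) =>
    Console.WriteLine("speech end at {0:F2}s", (double)e.Offset / recognizer.SampleRate);
recognizer.Recognized += (s, e) =>
    Console.WriteLine("recognized: {0}", e.Text);
// Feed raw 16-bit mono PCM at recognizer.SampleRate in arbitrary chunk sizes.
recognizer.Write(pcmBytes, 0, pcmBytes.Length);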
81 changes: 81 additions & 0 deletions NeMoOnnxSharp.Example/Program.cs
@@ -6,6 +6,7 @@
using System.Text;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace NeMoOnnxSharp.Example
{
@@ -31,6 +32,10 @@ static async Task Main(string[] args)
{
await FramePredict(true);
}
else if (task == "streamaudio")
{
await StreamAudio();
}
else
{
throw new InvalidDataException(task);
@@ -43,6 +48,7 @@ static async Task Transcribe()
string modelPath = await DownloadModelAsync("stt_en_quartznet15x5");
string inputDirPath = Path.Combine(appDirPath, "..", "..", "..", "..", "test_data");
string inputPath = Path.Combine(inputDirPath, "transcript.txt");

using var model = new EncDecCTCModel(modelPath);
using var reader = File.OpenText(inputPath);
string? line;
@@ -110,6 +116,81 @@ static async Task FramePredict(bool mbn)
}
}

static async Task StreamAudio()
{
string appDirPath = AppDomain.CurrentDomain.BaseDirectory;
string vadModelPath = await DownloadModelAsync("vad_marblenet");
string asrModelPath = await DownloadModelAsync("stt_en_quartznet15x5");
string inputDirPath = Path.Combine(appDirPath, "..", "..", "..", "..", "test_data");
string inputPath = Path.Combine(inputDirPath, "transcript.txt");

using var recognizer = new SpeechRecognizer(vadModelPath, asrModelPath);
using var ostream = new FileStream(Path.Combine(inputDirPath, "result.txt"), FileMode.Create);
using var writer = new StreamWriter(ostream);
int index = 0;
recognizer.SpeechStartDetected += (s, e) =>
{
double t = (double)e.Offset / recognizer.SampleRate;
Console.WriteLine("SpeechStartDetected {0}", t);
};
recognizer.SpeechEndDetected += (s, e) =>
{
double t = (double)e.Offset / recognizer.SampleRate;
Console.WriteLine("SpeechEndDetected {0}", t);
};
recognizer.Recognized += (s, e) =>
{
double t = (double)e.Offset / recognizer.SampleRate;
Console.WriteLine("Recognized {0} {1} {2}", t, e.Audio?.Length, e.Text);
string fileName = string.Format("recognized-{0}.wav", index);
writer.WriteLine("{0}|{1}|{2}", fileName, e.Audio?.Length, e.Text);
if (e.Audio != null)
{
WaveFile.WriteWAV(Path.Combine(inputDirPath, fileName), e.Audio, recognizer.SampleRate);
}
index += 1;
};
var stream = GetAllAudioStream(inputDirPath);
var buffer = new byte[1024];
while (true)
{
int bytesRead = stream.Read(buffer);
if (bytesRead == 0)
{
break;
}
recognizer.Write(buffer.AsSpan(0, bytesRead));
}
}

private static MemoryStream GetAllAudioStream(
string inputDirPath,
int sampleRate = 16000,
double gapSeconds = 1.0)
{
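// Builds a single in-memory PCM stream by concatenating every WAV file listed
// in transcript.txt, inserting gapSeconds of silence before the first clip and
// after each clip.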
string inputPath = Path.Combine(inputDirPath, "transcript.txt");
using var reader = File.OpenText(inputPath);
string? line;
var stream = new MemoryStream();
var waveform = new short[(int)(sampleRate * gapSeconds)];
var bytes = MemoryMarshal.Cast<short, byte>(waveform);
stream.Write(bytes);
while ((line = reader.ReadLine()) != null)
{
string[] parts = line.Split("|");
string name = parts[0];
string waveFile = Path.Combine(inputDirPath, name);
waveform = WaveFile.ReadWAV(waveFile, sampleRate);
bytes = MemoryMarshal.Cast<short, byte>(waveform);
stream.Write(bytes);
waveform = new short[(int)(sampleRate * gapSeconds)];
bytes = MemoryMarshal.Cast<short, byte>(waveform);
stream.Write(bytes);
}
stream.Seek(0, SeekOrigin.Begin);
return stream;
}

private static async Task<string> DownloadModelAsync(string model)
{
using var downloader = new ModelDownloader();
24 changes: 24 additions & 0 deletions NeMoOnnxSharp/SpeechRecognitionEventArgs.cs
@@ -0,0 +1,24 @@
// Copyright (c) Katsuya Iida. All Rights Reserved.
// See LICENSE in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace NeMoOnnxSharp
{
public class SpeechRecognitionEventArgs
{
public SpeechRecognitionEventArgs(ulong offset, string? text = null, short[]? audio = null)
{
Offset = offset;
Text = text;
Audio = audio;
}

public ulong Offset { get; private set; }
public string? Text { get; private set; }
public short[]? Audio { get; private set; }
}
}
198 changes: 198 additions & 0 deletions NeMoOnnxSharp/SpeechRecognizer.cs
@@ -0,0 +1,198 @@
// Copyright (c) Katsuya Iida. All Rights Reserved.
// See LICENSE in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;

namespace NeMoOnnxSharp
{
public class SpeechRecognizer : IDisposable
{
public delegate void SpeechStart(long position);
public delegate void SpeechEnd(long position, short[] audioSignal, string? transcript);

private readonly FrameVAD _frameVad;
private readonly EncDecCTCModel _asrModel;
private readonly int _audioBufferIncrease;
private readonly int _audioBufferSize;
private int _audioBufferIndex;
private long _currentPosition;
private byte[] _audioBuffer;
private bool _isSpeech;
private readonly float _speechStartThreshold;
private readonly float _speechEndThreshold;

private SpeechRecognizer(FrameVAD frameVad, EncDecCTCModel asrModel)
{
_frameVad = frameVad;
_asrModel = asrModel;
_currentPosition = 0;
_audioBufferIndex = 0;
_audioBufferSize = sizeof(short) * _frameVad.SampleRate * 2; // 2sec
_audioBufferIncrease = sizeof(short) * 5 * _frameVad.SampleRate; // 5sec
_audioBuffer = new byte[_audioBufferSize];
_isSpeech = false;
_speechStartThreshold = 0.7f;
_speechEndThreshold = 0.3f;
}

public SpeechRecognizer(string vadModelPath, string asrModelPath) : this(
new FrameVAD(vadModelPath), new EncDecCTCModel(asrModelPath))
{
}

public SpeechRecognizer(byte[] vadModel, byte[] asrModel) : this(
new FrameVAD(vadModel), new EncDecCTCModel(asrModel))
{
}

public int SampleRate => _frameVad.SampleRate;
public event EventHandler<SpeechRecognitionEventArgs>? Recognized;
public event EventHandler<SpeechRecognitionEventArgs>? SpeechStartDetected;
public event EventHandler<SpeechRecognitionEventArgs>? SpeechEndDetected;

public void Dispose()
{
    _frameVad.Dispose();
    _asrModel.Dispose();
}

public void Write(byte[] input, int offset, int count)
{
Write(input.AsSpan(offset, count));
}

public void Write(Span<byte> input)
{
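// While no speech has been detected, _audioBuffer acts as a fixed-size
// wrap-around scratch buffer; once speech starts it is grown by
// _audioBufferIncrease as needed so the whole utterance stays contiguous
// for the ASR model.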
while (input.Length > 0)
{
int len = input.Length;
if (_isSpeech)
{
if (len > _audioBuffer.Length - _audioBufferIndex)
{
var tmp = new byte[_audioBuffer.Length + _audioBufferIncrease];
Array.Copy(_audioBuffer, tmp, _audioBufferIndex);
_audioBuffer = tmp;
}
}
else
{
if (_audioBufferIndex >= _audioBuffer.Length)
{
_audioBufferIndex = 0;
}
len = Math.Min(_audioBuffer.Length - _audioBufferIndex, len);
}
input.Slice(0, len).CopyTo(_audioBuffer.AsSpan(_audioBufferIndex, len));
input = input.Slice(len);
int len2 = (len / sizeof(short)) * sizeof(short);
var audioSignal = MemoryMarshal.Cast<byte, short>(_audioBuffer.AsSpan(_audioBufferIndex, len2));
_audioBufferIndex += len;
_currentPosition += audioSignal.Length;
_Transcribe(audioSignal);
}
}

private void _Transcribe(Span<short> audioSignal)
{
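// pos is the sample offset, relative to _currentPosition (already advanced
// past this chunk), at which the first VAD probability applies; it is
// negative because the VAD output lags the input by PredictionOffset samples.
// It advances by HopLength for each returned probability.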
var pos = -(audioSignal.Length + _frameVad.PredictionOffset);
var result = _frameVad.Transcribe(audioSignal);
foreach (var prob in result)
{
if (_isSpeech)
{
if (prob < _speechEndThreshold)
{
_isSpeech = false;
int posBytes = pos * sizeof(short);
if (Recognized != null)
{
var audio = _audioBuffer.AsSpan(0, _audioBufferIndex + posBytes);
var x = MemoryMarshal.Cast<byte, short>(audio).ToArray();
string predictText = _asrModel.Transcribe(x);
Recognized(this, new SpeechRecognitionEventArgs(
(ulong)(_currentPosition + pos), predictText, x));
}
if (SpeechEndDetected != null)
{
SpeechEndDetected(this, new SpeechRecognitionEventArgs(
(ulong)(_currentPosition + pos)));
}
_ResetAudioBuffer(posBytes);
}
}
else
{
if (prob >= _speechStartThreshold)
{
_isSpeech = true;
if (SpeechStartDetected != null)
{
    SpeechStartDetected(this, new SpeechRecognitionEventArgs(
        (ulong)(_currentPosition + pos)));
}
int posBytes = pos * sizeof(short);
_ChangeAudioBufferForSpeech(posBytes);
}
}
pos += _frameVad.HopLength;
}
}

private void _ResetAudioBuffer(int posBytes)
{
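// posBytes is negative: speech ended -posBytes bytes before _audioBufferIndex.
// Start a fresh fixed-size buffer that keeps only the audio written after the
// detected speech end, so it is available for the next utterance.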
var tmp = new byte[_audioBufferSize];
Array.Copy(
_audioBuffer, _audioBufferIndex + posBytes,
tmp, 0,
-posBytes);
_audioBuffer = tmp;
_audioBufferIndex = -posBytes;
}

private void _ChangeAudioBufferForSpeech(int posBytes)
{
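// posBytes is negative: speech started -posBytes bytes before _audioBufferIndex.
// Make the audio since the speech start contiguous at the start of the buffer;
// audioBufferStart can be negative when that region wraps around the circular
// idle buffer, in which case the buffer is grown and unrolled in order.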
int audioBufferStart = _audioBufferIndex + posBytes;
int audioBufferEnd = _audioBufferIndex;
if (audioBufferStart >= 0)
{
Array.Copy(
_audioBuffer, audioBufferStart,
_audioBuffer, 0,
audioBufferEnd - audioBufferStart);
_audioBufferIndex = audioBufferEnd - audioBufferStart;
}
else if (audioBufferStart + _audioBuffer.Length >= audioBufferEnd)
{
var tmp = new byte[_audioBuffer.Length + _audioBufferIncrease];
Array.Copy(
_audioBuffer, audioBufferStart + _audioBuffer.Length,
tmp, 0,
-audioBufferStart);
Array.Copy(
_audioBuffer, 0,
tmp, -audioBufferStart,
audioBufferEnd);
_audioBuffer = tmp;
_audioBufferIndex = audioBufferEnd - audioBufferStart;
}
else
{
var tmp = new byte[_audioBuffer.Length + _audioBufferIncrease];
Array.Copy(
_audioBuffer, audioBufferEnd,
tmp, 0,
_audioBuffer.Length - audioBufferEnd);
Array.Copy(
_audioBuffer, 0,
tmp, _audioBuffer.Length - audioBufferEnd,
audioBufferEnd);
_audioBuffer = tmp;
_audioBufferIndex = _audioBuffer.Length;
}
}
}
}
2 changes: 2 additions & 0 deletions test_data/.gitignore
@@ -1 +1,3 @@
/generated-*.wav
/recognized-*.wav
/result.txt
