Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade silero vad v5 #1

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified data/models/silero-vad/silero_vad.onnx
Binary file not shown.
1 change: 1 addition & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct transcription_filter_data {

/* Resampler */
audio_resampler_t *resampler_to_whisper;
struct circlebuf resampled_buffer;

/* whisper */
std::string whisper_model_path;
Expand Down
4 changes: 4 additions & 0 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
// calculate timestamp offset from the start of the stream
info.timestamp_offset_ns = now_ns() - gf->start_timestamp_ms * 1000000;
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
gf->wshiper_thread_cv.notify_one();
}

return audio;
Expand Down Expand Up @@ -154,6 +155,8 @@ void transcription_filter_destroy(void *data)
}
circlebuf_free(&gf->info_buffer);

circlebuf_free(&gf->resampled_buffer);

if (gf->captions_monitor.isEnabled()) {
gf->captions_monitor.stopThread();
}
Expand Down Expand Up @@ -444,6 +447,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
}
circlebuf_init(&gf->info_buffer);
circlebuf_init(&gf->whisper_buffer);
circlebuf_init(&gf->resampled_buffer);

// allocate copy buffers
gf->copy_buffers[0] =
Expand Down
32 changes: 13 additions & 19 deletions src/whisper-utils/silero-vad-onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ void VadIterator::init_onnx_model(const SileroString &model_path)
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};

void VadIterator::reset_states(bool reset_hc)
void VadIterator::reset_states(bool reset_state)
{
if (reset_hc) {
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
if (reset_state) {
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggered = false;
}
temp_end = 0;
Expand All @@ -115,19 +115,16 @@ float VadIterator::predict_one(const std::vector<float> &data)
input.assign(data.begin(), data.end());
Ort::Value input_ort = Ort::Value::CreateTensor<float>(memory_info, input.data(),
input.size(), input_node_dims, 2);
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
memory_info, _state.data(), _state.size(), state_node_dims, 3);
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(memory_info, sr.data(), sr.size(),
sr_node_dims, 1);
Ort::Value h_ort =
Ort::Value::CreateTensor<float>(memory_info, _h.data(), _h.size(), hc_node_dims, 3);
Ort::Value c_ort =
Ort::Value::CreateTensor<float>(memory_info, _c.data(), _c.size(), hc_node_dims, 3);

// Clear and add inputs
ort_inputs.clear();
ort_inputs.emplace_back(std::move(input_ort));
ort_inputs.emplace_back(std::move(state_ort));
ort_inputs.emplace_back(std::move(sr_ort));
ort_inputs.emplace_back(std::move(h_ort));
ort_inputs.emplace_back(std::move(c_ort));

// Infer
ort_outputs = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(),
Expand All @@ -136,10 +133,8 @@ float VadIterator::predict_one(const std::vector<float> &data)

// Output probability & update h,c recursively
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
float *hn = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_h.data(), hn, size_hc * sizeof(float));
float *cn = ort_outputs[2].GetTensorMutableData<float>();
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_state.data(), stateN, size_state * sizeof(float));

return speech_prob;
}
Expand Down Expand Up @@ -264,9 +259,9 @@ void VadIterator::predict(const std::vector<float> &data)
}
};

void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)
void VadIterator::process(const std::vector<float> &input_wav, bool reset_state)
{
reset_states(reset_hc);
reset_states(reset_state);

audio_length_samples = (int)input_wav.size();

Expand All @@ -290,7 +285,7 @@ void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)

void VadIterator::process(const std::vector<float> &input_wav, std::vector<float> &output_wav)
{
process(input_wav, true);
process(input_wav);
collect_chunks(input_wav, output_wav);
}

Expand Down Expand Up @@ -352,8 +347,7 @@ VadIterator::VadIterator(const SileroString &ModelPath, int Sample_rate, int win
input_node_dims[0] = 1;
input_node_dims[1] = window_size_samples;

_h.resize(size_hc);
_c.resize(size_hc);
_state.resize(size_state);
sr.resize(1);
sr[0] = sample_rate;
};
23 changes: 12 additions & 11 deletions src/whisper-utils/silero-vad-onnx.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,20 @@ class VadIterator {
private:
void init_engine_threads(int inter_threads, int intra_threads);
void init_onnx_model(const SileroString &model_path);
void reset_states(bool reset_hc);
void reset_states(bool reset_state);
float predict_one(const std::vector<float> &data);
void predict(const std::vector<float> &data);

public:
void process(const std::vector<float> &input_wav, bool reset_hc = true);
void process(const std::vector<float> &input_wav, bool reset_state = true);
void process(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
const std::vector<timestamp_t> get_speech_timestamps() const;
void drop_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void set_threshold(float threshold_) { this->threshold = threshold_; }

int64_t get_window_size_samples() const { return window_size_samples; }

private:
// model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
Expand Down Expand Up @@ -84,27 +86,26 @@ class VadIterator {
// Inputs
std::vector<Ort::Value> ort_inputs;

std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
std::vector<const char *> input_node_names = {"input", "state", "sr"};
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
std::vector<float> _state;
std::vector<int64_t> sr;
unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
std::vector<float> _h;
std::vector<float> _c;

int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1};
const int64_t hc_node_dims[3] = {2, 1, 64};

// Outputs
std::vector<Ort::Value> ort_outputs;
std::vector<const char *> output_node_names = {"output", "hn", "cn"};
std::vector<const char *> output_node_names = {"output", "stateN"};

public:
// Construction
VadIterator(const SileroString &ModelPath, int Sample_rate = 16000,
int windows_frame_size = 64, float Threshold = 0.5,
int min_silence_duration_ms = 0, int speech_pad_ms = 64,
int min_speech_duration_ms = 64,
int windows_frame_size = 32, float Threshold = 0.5,
int min_silence_duration_ms = 0, int speech_pad_ms = 32,
int min_speech_duration_ms = 32,
float max_speech_duration_s = std::numeric_limits<float>::infinity());

// Default constructor
Expand Down
84 changes: 50 additions & 34 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <obs-module.h>

#include <util/profiler.hpp>

#include "plugin-support.h"
#include "transcription-filter-data.h"
#include "whisper-processing.h"
Expand Down Expand Up @@ -389,22 +391,43 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
num_frames_from_infos, overlap_size);
gf->last_num_frames = num_frames_from_infos + overlap_size;

// resample to 16kHz
float *resampled_16khz[MAX_PREPROC_CHANNELS];
uint32_t resampled_16khz_frames;
uint64_t ts_offset;
audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz,
&resampled_16khz_frames, &ts_offset,
(const uint8_t **)gf->copy_buffers,
(uint32_t)num_frames_from_infos);
{
// resample to 16kHz
float *resampled_16khz[MAX_PREPROC_CHANNELS];
uint32_t resampled_16khz_frames;
uint64_t ts_offset;
{
ProfileScope("resample");
audio_resampler_resample(gf->resampler_to_whisper,
(uint8_t **)resampled_16khz,
&resampled_16khz_frames, &ts_offset,
(const uint8_t **)gf->copy_buffers,
(uint32_t)num_frames_from_infos);
}

obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms",
(int)gf->channels, (int)resampled_16khz_frames,
(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0],
resampled_16khz_frames * sizeof(float));
}

if (gf->resampled_buffer.size < (gf->vad->get_window_size_samples() * sizeof(float)))
return last_vad_state;

obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", (int)gf->channels,
(int)resampled_16khz_frames,
(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
size_t len =
gf->resampled_buffer.size / (gf->vad->get_window_size_samples() * sizeof(float));

std::vector<float> vad_input(resampled_16khz[0],
resampled_16khz[0] + resampled_16khz_frames);
gf->vad->process(vad_input, false);
std::vector<float> vad_input;
vad_input.resize(len * gf->vad->get_window_size_samples());
circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(),
vad_input.size() * sizeof(float));

obs_log(gf->log_level, "sending %d frames to vad", vad_input.size());
{
ProfileScope("vad->process");
gf->vad->process(vad_input, !last_vad_state.vad_on);
}

const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000;
const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000;
Expand All @@ -414,8 +437,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v

std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
if (stamps.size() == 0) {
obs_log(gf->log_level, "VAD detected no speech in %d frames",
resampled_16khz_frames);
obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size());
if (last_vad_state.vad_on) {
obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference");
run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms,
Expand All @@ -425,7 +447,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
}

if (gf->enable_audio_chunks_callback) {
audio_chunk_callback(gf, resampled_16khz[0], resampled_16khz_frames,
audio_chunk_callback(gf, vad_input.data(), vad_input.size(),
VAD_STATE_IS_OFF,
{DETECTION_RESULT_SILENCE,
"[silence]",
Expand All @@ -449,16 +471,16 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
}

int end_frame = stamps[i].end;
if (i == stamps.size() - 1 && stamps[i].end < (int)resampled_16khz_frames) {
if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) {
// take at least 100ms of audio after the last speech segment, if available
end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10,
(int)resampled_16khz_frames);
(int)vad_input.size());
}

const int number_of_frames = end_frame - start_frame;

// push the data into gf-whisper_buffer
circlebuf_push_back(&gf->whisper_buffer, resampled_16khz[0] + start_frame,
circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame,
number_of_frames * sizeof(float));

obs_log(gf->log_level,
Expand All @@ -469,7 +491,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE);

// segment "end" is in the middle of the buffer, send it to inference
if (stamps[i].end < (int)resampled_16khz_frames) {
if (stamps[i].end < (int)vad_input.size()) {
// new "ending" segment (not up to the end of the buffer)
obs_log(gf->log_level, "VAD segment end -> send to inference");
// find the end timestamp of the segment
Expand Down Expand Up @@ -542,30 +564,24 @@ void whisper_loop(void *data)
obs_log(gf->log_level, "Starting whisper thread");

vad_state current_vad_state = {false, 0, 0, 0};
// 500 ms worth of audio is needed for VAD segmentation
uint32_t min_num_bytes_for_vad = (gf->sample_rate / 2) * sizeof(float);

const char *whisper_loop_name = "Whisper loop";
profile_register_root(whisper_loop_name, 50 * 1000 * 1000);

// Thread main loop
while (true) {
ProfileScope(whisper_loop_name);
{
ProfileScope("lock whisper ctx");
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
ProfileScope("locked whisper ctx");
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
break;
}
}

uint32_t num_bytes_on_input = 0;
{
// scoped lock the buffer mutex
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
num_bytes_on_input = (uint32_t)gf->input_buffers[0].size;
}

// only run vad segmentation if there are at least 500 ms of audio in the buffer
if (num_bytes_on_input > min_num_bytes_for_vad) {
current_vad_state = vad_based_segmentation(gf, current_vad_state);
}
current_vad_state = vad_based_segmentation(gf, current_vad_state);

if (!gf->cleared_last_sub) {
// check if we should clear the current sub depending on the minimum subtitle duration
Expand Down
Loading