Skip to content

Commit

Permalink
refactor: Update VAD threshold in transcription filter
Browse files Browse the repository at this point in the history
  • Loading branch information
royshil committed Jun 5, 2024
1 parent dab16d2 commit 5cf661d
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 39 deletions.
1 change: 1 addition & 0 deletions data/locale/en-US.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
LocalVocalPlugin="LocalVocal Plugin"
transcription_filterAudioFilter="LocalVocal Transcription"
vad_enabled="VAD Enabled"
vad_threshold="VAD Threshold"
log_level="Internal Log Level"
log_words="Log Output to Console"
caption_to_stream="Stream Captions"
Expand Down
12 changes: 10 additions & 2 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);
}
}

void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
Expand Down Expand Up @@ -524,6 +529,7 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_int(s, "buffer_num_chars_per_line", 30);

obs_data_set_default_bool(s, "vad_enabled", true);
obs_data_set_default_double(s, "vad_threshold", 0.5);
obs_data_set_default_int(s, "log_level", LOG_DEBUG);
obs_data_set_default_bool(s, "log_words", false);
obs_data_set_default_bool(s, "caption_to_stream", false);
Expand Down Expand Up @@ -799,7 +805,7 @@ obs_properties_t *transcription_filter_properties(void *data)
{"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
"overlap_size_msec", "step_by_step_processing", "min_sub_duration",
"process_while_muted", "buffered_output", "vad_enabled", "log_level",
"suppress_sentences", "sentence_psum_accept_thresh", "send_timed_metadata"}) {
"suppress_sentences", "sentence_psum_accept_thresh"}) {
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
show_hide);
}
Expand Down Expand Up @@ -834,7 +840,6 @@ obs_properties_t *transcription_filter_properties(void *data)

obs_properties_add_bool(ppts, "log_words", MT_("log_words"));
obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream"));
obs_properties_add_bool(ppts, "send_timed_metadata", MT_("send_timed_metadata"));

obs_properties_add_int_slider(ppts, "min_sub_duration", MT_("min_sub_duration"), 1000, 5000,
50);
Expand All @@ -844,6 +849,9 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_properties_add_bool(ppts, "process_while_muted", MT_("process_while_muted"));

obs_properties_add_bool(ppts, "vad_enabled", MT_("vad_enabled"));
// add vad threshold slider
obs_properties_add_float_slider(ppts, "vad_threshold", MT_("vad_threshold"), 0.0, 1.0,
0.05);

obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
Expand Down
1 change: 1 addition & 0 deletions src/whisper-utils/silero-vad-onnx.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class VadIterator {
void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
const std::vector<timestamp_t> get_speech_timestamps() const;
void drop_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void set_threshold(float threshold) { this->threshold = threshold; }

private:
// model config
Expand Down
69 changes: 47 additions & 22 deletions src/whisper-utils/token-buffer-thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@
#define NEWLINE "\n"
#endif

TokenBufferThread::TokenBufferThread() noexcept
: gf(nullptr),
numSentences(1),
numPerSentence(1),
maxTime(0),
stop(true)
{
}

TokenBufferThread::~TokenBufferThread()
{
{
std::lock_guard<std::mutex> lock(inputQueueMutex);
stop = true;
}
condVar.notify_all();
if (workerThread.joinable()) {
workerThread.join();
}
stopThread();
}

void TokenBufferThread::initialize(struct transcription_filter_data *gf_,
Expand All @@ -41,13 +43,18 @@ void TokenBufferThread::initialize(struct transcription_filter_data *gf_,
this->segmentation = segmentation_;
this->maxTime = maxTime_;
this->stop = false;
this->presentationQueueMutex = std::make_unique<std::mutex>();
this->inputQueueMutex = std::make_unique<std::mutex>();
this->workerThread = std::thread(&TokenBufferThread::monitor, this);
}

void TokenBufferThread::stopThread()
{
std::lock_guard<std::mutex> lock(inputQueueMutex);
stop = true;
{
std::lock_guard<std::mutex> lock(*inputQueueMutex);
std::lock_guard<std::mutex> lockPresentation(*presentationQueueMutex);
stop = true;
}
condVar.notify_all();
if (workerThread.joinable()) {
workerThread.join();
Expand Down Expand Up @@ -85,7 +92,7 @@ void TokenBufferThread::addSentence(const std::string &sentence)
}
#endif

std::lock_guard<std::mutex> lock(inputQueueMutex);
std::lock_guard<std::mutex> lock(*inputQueueMutex);

// add the reconstructed sentence to the wordQueue
for (const auto &character : characters) {
Expand All @@ -97,11 +104,11 @@ void TokenBufferThread::addSentence(const std::string &sentence)
void TokenBufferThread::clear()
{
{
std::lock_guard<std::mutex> lock(inputQueueMutex);
std::lock_guard<std::mutex> lock(*inputQueueMutex);
inputQueue.clear();
}
{
std::lock_guard<std::mutex> lock(presentationQueueMutex);
std::lock_guard<std::mutex> lock(*presentationQueueMutex);
presentationQueue.clear();
}
this->callback("");
Expand All @@ -114,8 +121,14 @@ void TokenBufferThread::monitor()
this->callback("");

while (!this->stop) {
std::string caption_out;

if (presentationQueueMutex == nullptr) {
break;
}

{
std::unique_lock<std::mutex> lockPresentation(this->presentationQueueMutex);
std::lock_guard<std::mutex> lockPresentation(*presentationQueueMutex);
// condition presentation queue
if (presentationQueue.size() == this->numSentences * this->numPerSentence) {
// pop a whole sentence from the presentation queue front
Expand All @@ -125,7 +138,11 @@ void TokenBufferThread::monitor()
}

{
std::unique_lock<std::mutex> lock(this->inputQueueMutex);
if (inputQueueMutex == nullptr) {
break;
}

std::lock_guard<std::mutex> lock(*inputQueueMutex);

if (!inputQueue.empty()) {
// if there are token on the input queue
Expand Down Expand Up @@ -194,22 +211,30 @@ void TokenBufferThread::monitor()
int count = WideCharToMultiByte(CP_UTF8, 0, caption.c_str(),
(int)caption.length(), NULL, 0,
NULL, NULL);
std::string caption_out(count, 0);
caption_out = std::string(count, 0);
WideCharToMultiByte(CP_UTF8, 0, caption.c_str(),
(int)caption.length(), &caption_out[0], count,
NULL, NULL);
#else
std::string caption_out(caption.begin(), caption.end());
caption_out = std::string(caption.begin(), caption.end());
#endif

// emit the caption
this->callback(caption_out);
}
}

if (caption_out.empty()) {
// if no caption was built, sleep for a while
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}

// emit the caption
this->callback(caption_out);

// check the input queue size (iqs), if it's big - sleep less
std::this_thread::sleep_for(
std::chrono::milliseconds(inputQueue.size() > 15 ? 66 : 100));
std::this_thread::sleep_for(std::chrono::milliseconds(inputQueue.size() > 30 ? 33
: inputQueue.size() > 15
? 66
: 100));
}

obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
Expand Down
6 changes: 3 additions & 3 deletions src/whisper-utils/token-buffer-thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ enum TokenBufferSegmentation { SEGMENTATION_WORD = 0, SEGMENTATION_TOKEN, SEGMEN
class TokenBufferThread {
public:
// default constructor
TokenBufferThread() = default;
TokenBufferThread() noexcept;

~TokenBufferThread();
void initialize(struct transcription_filter_data *gf,
Expand All @@ -51,8 +51,8 @@ class TokenBufferThread {
std::deque<TokenBufferString> inputQueue;
std::deque<TokenBufferString> presentationQueue;
std::thread workerThread;
std::mutex inputQueueMutex;
std::mutex presentationQueueMutex;
std::unique_ptr<std::mutex> inputQueueMutex;
std::unique_ptr<std::mutex> presentationQueueMutex;
std::condition_variable condVar;
std::function<void(std::string)> callback;
std::chrono::seconds maxTime;
Expand Down
13 changes: 1 addition & 12 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,20 +316,9 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
uint32_t num_frames_from_infos = 0;
uint64_t start_timestamp = 0;
uint64_t end_timestamp = 0;
size_t overlap_size = 0; //gf->sample_rate / 10;
size_t overlap_size = 0;

for (size_t c = 0; c < gf->channels; c++) {
// if (!current_vad_on && gf->last_num_frames > overlap_size) {
// if (c == 0) {
// // print only once
// obs_log(gf->log_level, "VAD overlap: %lu frames", overlap_size);
// }
// // move 100ms from the end of copy_buffers to the beginning
// memmove(gf->copy_buffers[c], gf->copy_buffers[c] + gf->last_num_frames - overlap_size,
// overlap_size * sizeof(float));
// } else {
// overlap_size = 0;
// }
// zero the rest of copy_buffers
memset(gf->copy_buffers[c] + overlap_size, 0,
(gf->frames - overlap_size) * sizeof(float));
Expand Down

0 comments on commit 5cf661d

Please sign in to comment.