Skip to content

Commit

Permalink
Fix special tokens (#93)
Browse files Browse the repository at this point in the history
* Update version to 0.2.4 in buildspec.json

* Update special token handling in whisper-processing.cpp

* Update special token handling in whisper-processing.cpp
  • Loading branch information
royshil authored Apr 26, 2024
1 parent f36f6ec commit 3b955e3
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
2 changes: 1 addition & 1 deletion buildspec.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
},
"name": "obs-localvocal",
"displayName": "OBS Localvocal",
"version": "0.2.3",
"version": "0.2.4",
"author": "Roy Shilkrot",
"website": "https://github.com/occ-ai/obs-localvocal",
"email": "[email protected]",
Expand Down
14 changes: 14 additions & 0 deletions src/transcription-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <sstream>
#include <algorithm>
#include <vector>

#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
#define is_trail_byte(c) (((c)&0xc0) == 0x80)
Expand Down Expand Up @@ -102,3 +103,16 @@ std::string remove_leading_trailing_nonalpha(const std::string &str)
}));
return str_copy;
}

std::vector<std::string> split(const std::string &string, char delimiter)
{
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(string);
while (std::getline(tokenStream, token, delimiter)) {
if (!token.empty()) {
tokens.push_back(token);
}
}
return tokens;
}
2 changes: 2 additions & 0 deletions src/transcription-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
#define TRANSCRIPTION_UTILS_H

#include <string>
#include <vector>

std::string fix_utf8(const std::string &str);
std::string remove_leading_trailing_nonalpha(const std::string &str);
std::vector<std::string> split(const std::string &string, char delimiter);

#endif // TRANSCRIPTION_UTILS_H
23 changes: 13 additions & 10 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "transcription-filter-data.h"
#include "whisper-processing.h"
#include "whisper-utils.h"
#include "transcription-utils.h"

#include <algorithm>
#include <cctype>
Expand Down Expand Up @@ -282,6 +283,10 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') {
keep = false;
}
// if this is a special token, don't keep it
if (token.id >= 50256) {
keep = false;
}
if ((j == n_tokens - 2 || j == n_tokens - 3) && token.p < 0.5) {
keep = false;
}
Expand Down Expand Up @@ -312,20 +317,18 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter

// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
std::string suppress_sentences_copy = gf->suppress_sentences;
size_t pos = 0;
std::string token;
while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) {
token = suppress_sentences_copy.substr(0, pos);
suppress_sentences_copy.erase(0, pos + 1);
if (text == suppress_sentences_copy) {
obs_log(gf->log_level, "Suppressing sentence: %s",
// split the suppression list by newline into individual sentences
std::vector<std::string> suppress_sentences_list =
split(gf->suppress_sentences, '\n');
// check if the text is in the suppression list
for (const std::string &suppress_sentence : suppress_sentences_list) {
if (text.find(suppress_sentence) != std::string::npos) {
obs_log(gf->log_level, "Suppressed sentence: '%s'",
text.c_str());
return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}};
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
}
}
}

if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text.c_str());
Expand Down

0 comments on commit 3b955e3

Please sign in to comment.