Vad reconfig #922 (Draft)

wants to merge 16 commits into base: main

Changes from all commits
116 changes: 87 additions & 29 deletions backend/routers/transcribe.py
@@ -1,14 +1,20 @@
from collections import deque
from concurrent.futures import ThreadPoolExecutor
import threading
import uuid
from datetime import datetime, timezone
from enum import Enum

import numpy as np
import opuslib
from fastapi import APIRouter
from fastapi.websockets import WebSocketDisconnect, WebSocket
from pydub import AudioSegment
from starlette.websockets import WebSocketState

from database.redis_db import get_user_speech_profile, get_user_speech_profile_duration
from utils.stt.streaming import process_audio_dg, send_initial_file
from utils.stt.vad import VADIterator, model, is_speech_present
import database.memories as memories_db
import database.processing_memories as processing_memories_db
from models.memory import Memory, TranscriptSegment
@@ -21,6 +27,7 @@

router = APIRouter()

thread_pool = ThreadPoolExecutor(max_workers=4)

# @router.post("/v1/transcribe", tags=['v1'])
# will be used again in Friend V2
@@ -222,70 +229,122 @@ async def deepgram_socket_send(data):
await websocket.close(code=websocket_close_code)
return

vad_iterator = VADIterator(model, sampling_rate=sample_rate) # threshold=0.9
threshold = 0.7
vad_iterator = VADIterator(model, sampling_rate=sample_rate, threshold=threshold)
window_size_samples = 256 if sample_rate == 8000 else 512
window_size_bytes = int(window_size_samples * 2 * 2.5)
window_size_bytes = int(window_size_samples * 2 * 2.5 * 10)

decoder = opuslib.Decoder(sample_rate, channels)
if codec == 'opus':
decoder = opuslib.Decoder(sample_rate, channels)

async def receive_audio(dg_socket1, dg_socket2, soniox_socket, speechmatics_socket1):
nonlocal websocket_active
nonlocal websocket_close_code
nonlocal timer_start
nonlocal decoder
speech_timeout = 2.0
DEFAULT_VAD_CHECK_INTERVAL_NONACTIVE = 0.1
opus_vad_check_interval_nonactive = DEFAULT_VAD_CHECK_INTERVAL_NONACTIVE
opus_vad_check_interval_active = 0.250
timer_start = time.time()

# nonlocal audio_buffer
# audio_buffer = bytearray()
# speech_state = SpeechState.no_speech

is_speech_active = False
last_vad_check_time = 0
last_speech_time = 0
start_speech_inactivity_time = 0
soniox_sendanything_interval = 7
must_send_data = False
audio_buffer = deque(maxlen=window_size_samples)
databuffer = bytearray(b"")

try:
while websocket_active:
raw_data = await websocket.receive_bytes()
raw_data_recv_time = time.time()
data = raw_data[:]

if codec == 'opus' and sample_rate == 16000:
data = decoder.decode(bytes(data), frame_size=160)

# audio_buffer.extend(data)
# if len(audio_buffer) < window_size_bytes:
# continue

# speech_state = is_speech_present(audio_buffer[:window_size_bytes], vad_iterator, window_size_samples)

# if speech_state == SpeechState.no_speech:
# audio_buffer = audio_buffer[window_size_bytes:]
# continue
elif codec not in ['pcm8', 'pcm16']:
raise ValueError(f"Unsupported codec: {codec}")
samples = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0 # Normalize to [-1, 1] early to reduce CPU usage later
audio_buffer.extend(samples)

# VAD processing
if len(audio_buffer) >= window_size_samples:
if must_send_data: # Flush the remaining batched databuffer (current use case)
must_send_data = False
pass
elif codec == 'opus' and raw_data_recv_time - last_speech_time < speech_timeout/2 and is_speech_active:
pass
elif codec == 'opus' and raw_data_recv_time - last_vad_check_time < opus_vad_check_interval_active and is_speech_active:
pass
elif codec == 'opus' and raw_data_recv_time - last_vad_check_time < opus_vad_check_interval_nonactive and not is_speech_active:
continue
else:
last_vad_check_time = time.time()
is_containing_speech = await asyncio.get_event_loop().run_in_executor(thread_pool, is_speech_present, list(audio_buffer), vad_iterator, window_size_samples)
if is_containing_speech == SpeechState.speech_found:
# print('+Detected speech at ' + str(raw_data_recv_time))
is_speech_active = True
last_speech_time = raw_data_recv_time + speech_timeout/2
elif is_speech_active:
if raw_data_recv_time - last_speech_time > speech_timeout:
is_speech_active = False
vad_iterator.reset_states()
must_send_data = True # Flush the remaining batched databuffer
start_speech_inactivity_time = time.time() # For the Soniox send-anything keepalive
opus_vad_check_interval_nonactive = DEFAULT_VAD_CHECK_INTERVAL_NONACTIVE # Reset the inactive VAD check interval
print('-No speech detected')
continue
else:
if time.time() - start_speech_inactivity_time > soniox_sendanything_interval: # Soniox keepalive: periodically send an empty payload
start_speech_inactivity_time = time.time()
if soniox_socket is not None:
asyncio.create_task(soniox_socket.send(b''))
if dg_socket1 is not None:
asyncio.create_task(dg_socket1.send(b''))
if dg_socket2 is not None:
asyncio.create_task(dg_socket2.send(b''))
if opus_vad_check_interval_nonactive <= 0.15: # Slowly widen the interval to cut CPU usage so other tasks can run
opus_vad_check_interval_nonactive += 0.0005
continue
else:
continue

databuffer.extend(data)

if len(databuffer) < window_size_bytes and not must_send_data: # Batch to reduce I/O to the STT sockets
continue

if soniox_socket is not None:
await soniox_socket.send(data)
asyncio.create_task(soniox_socket.send(databuffer))

if speechmatics_socket1 is not None:
await speechmatics_socket1.send(data)
await speechmatics_socket1.send(databuffer)

if deepgram_socket is not None:
elapsed_seconds = time.time() - timer_start
if elapsed_seconds > duration or not dg_socket2:
dg_socket1.send(data)
asyncio.create_task(dg_socket1.send(databuffer))
if dg_socket2:
print('Killing socket2')
dg_socket2.finish()
dg_socket2 = None
else:
dg_socket2.send(data)
asyncio.create_task(dg_socket2.send(databuffer))

# audio_buffer = audio_buffer[window_size_bytes:]
databuffer = bytearray(b"")

except WebSocketDisconnect:
print("WebSocket disconnected")
except WebSocketDisconnect as e:
print(f"WebSocket disconnected: {e}")
except Exception as e:
print(f'Could not process audio: error {e}')
websocket_close_code = 1011
finally:
websocket_active = False
if dg_socket1:
dg_socket1.finish()
await dg_socket1.finish()
if dg_socket2:
dg_socket2.finish()
await dg_socket2.finish()
if soniox_socket:
await soniox_socket.close()
if speechmatics_socket:
@@ -299,7 +358,6 @@ async def send_heartbeat():
try:
while websocket_active:
await asyncio.sleep(30)
# print('send_heartbeat')
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
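Taken together, the new receive loop gates audio on VAD and batches it before it ever touches an STT socket. Note the batch-size arithmetic: with `window_size_samples = 512`, `window_size_bytes = 512 * 2 * 2.5 * 10 = 25,600` bytes, roughly 0.8 s of 16 kHz mono PCM16. A standalone sketch of the pattern (the names `detect_speech`, `frames`, and `send` are illustrative stand-ins, not the PR's API):

```python
import asyncio
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor

thread_pool = ThreadPoolExecutor(max_workers=4)

def detect_speech(samples):
    # Stand-in for silero-vad; the PR calls
    # is_speech_present(list(audio_buffer), vad_iterator, window_size_samples).
    return max(samples, default=0.0) > 0.1

async def gate_and_batch(frames, send,
                         window_size_samples=512,
                         window_size_bytes=int(512 * 2 * 2.5 * 10),
                         speech_timeout=2.0):
    """Forward audio only while speech is active, batching bytes to cut STT I/O."""
    loop = asyncio.get_event_loop()
    audio_buffer = deque(maxlen=window_size_samples)  # rolling float32 window for VAD
    databuffer = bytearray()
    is_speech_active = False
    last_speech_time = 0.0

    async for pcm_bytes, samples in frames:  # (raw PCM16 bytes, normalized float32 samples)
        audio_buffer.extend(samples)
        now = time.time()
        if len(audio_buffer) >= window_size_samples:
            # Run VAD in a worker thread so the event loop keeps receiving audio.
            if await loop.run_in_executor(thread_pool, detect_speech, list(audio_buffer)):
                is_speech_active, last_speech_time = True, now
            elif is_speech_active and now - last_speech_time > speech_timeout:
                is_speech_active = False  # speech ended; the PR also flushes the last batch here
        if not is_speech_active:
            continue
        databuffer.extend(pcm_bytes)
        if len(databuffer) >= window_size_bytes:  # batch before touching the STT socket
            await send(bytes(databuffer))
            databuffer.clear()
```

The real loop layers per-codec VAD check intervals and the Soniox keepalive on top, but this is the control flow the `is_speech_active` / `must_send_data` state variables implement.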
43 changes: 24 additions & 19 deletions backend/utils/stt/streaming.py
@@ -99,7 +99,7 @@ async def process_audio_dg(
):
print('process_audio_dg', language, sample_rate, channels, preseconds)

def on_message(self, result, **kwargs):
async def on_message(self, result, **kwargs):
# print(f"Received message from Deepgram") # Log when message is received
sentence = result.channel.alternatives[0].transcript
# print(sentence)
@@ -139,46 +139,46 @@ def on_message(self, result, **kwargs):
# stream
stream_transcript(segments, stream_id)

def on_error(self, error, **kwargs):
async def on_error(self, error, **kwargs):
print(f"Error: {error}")

print("Connecting to Deepgram") # Log before connection attempt
return connect_to_deepgram(on_message, on_error, language, sample_rate, channels)
return await connect_to_deepgram(on_message, on_error, language, sample_rate, channels)


def process_segments(uid: str, segments: list[dict]):
token = notification_db.get_token_only(uid) # TODO: don't retrieve the token before knowing whether to notify
trigger_realtime_integrations(uid, token, segments)


def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, channels: int):
async def connect_to_deepgram(on_message, on_error, language: str, sample_rate: int, channels: int):
# 'wss://api.deepgram.com/v1/listen?encoding=linear16&sample_rate=8000&language=$recordingsLanguage&model=nova-2-general&no_delay=true&endpointing=100&interim_results=false&smart_format=true&diarize=true'
try:
dg_connection = deepgram.listen.websocket.v("1")
dg_connection = deepgram.listen.asyncwebsocket.v("1")
dg_connection.on(LiveTranscriptionEvents.Transcript, on_message)
dg_connection.on(LiveTranscriptionEvents.Error, on_error)

def on_open(self, open, **kwargs):
async def on_open(self, open, **kwargs):
print("Connection Open")

def on_metadata(self, metadata, **kwargs):
async def on_metadata(self, metadata, **kwargs):
print(f"Metadata: {metadata}")

def on_speech_started(self, speech_started, **kwargs):
async def on_speech_started(self, speech_started, **kwargs):
print("Speech Started")

def on_utterance_end(self, utterance_end, **kwargs):
async def on_utterance_end(self, utterance_end, **kwargs):
print("Utterance End")
global is_finals
if len(is_finals) > 0:
utterance = " ".join(is_finals)
print(f"Utterance End: {utterance}")
is_finals = []

def on_close(self, close, **kwargs):
async def on_close(self, close, **kwargs):
print("Connection Closed")

def on_unhandled(self, unhandled, **kwargs):
async def on_unhandled(self, unhandled, **kwargs):
print(f"Unhandled Websocket Message: {unhandled}")

dg_connection.on(LiveTranscriptionEvents.Open, on_open)
@@ -203,7 +203,7 @@ def on_unhandled(self, unhandled, **kwargs):
sample_rate=sample_rate,
encoding='linear16'
)
result = dg_connection.start(options)
result = await dg_connection.start(options)
print('Deepgram connection started:', result)
return dg_connection
except Exception as e:
@@ -253,11 +253,17 @@ async def process_audio_soniox(stream_transcript, stream_id: int, sample_rate: i
# Send the initial request
await soniox_socket.send(json.dumps(request))
print(f"Sent initial request: {request}")


# Disable the ping timeout so the socket survives long quiet stretches under high traffic
soniox_socket.ping_timeout = None

# Start listening for messages from Soniox
async def on_message():
try:
async for message in soniox_socket:
while True:
message = await soniox_socket.recv()
if message is None:
continue
response = json.loads(message)
# print(response)
fw = response['fw']
Expand Down Expand Up @@ -308,11 +314,10 @@ async def on_message():
except Exception as e:
print(f"Error receiving from Soniox: {e}")
finally:
if not soniox_socket.closed:
await soniox_socket.close()
print("Soniox WebSocket closed in on_message.")

# Start the on_message coroutine
await soniox_socket.close()
print("Soniox WebSocket closed in on_message.")

# Start the on_message task
asyncio.create_task(on_message())

# Return the Soniox WebSocket object
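The switch from `deepgram.listen.websocket` to `deepgram.listen.asyncwebsocket` means every registered handler must be a coroutine and `start()` must be awaited. A minimal sketch of the async wiring, assuming the v3 Deepgram Python SDK (`DeepgramClient` and `LiveOptions` as in the SDK docs; the API-key handling here is illustrative):

```python
from deepgram import DeepgramClient, LiveOptions, LiveTranscriptionEvents

async def connect(api_key: str, language: str, sample_rate: int, channels: int):
    deepgram = DeepgramClient(api_key)
    dg_connection = deepgram.listen.asyncwebsocket.v("1")

    # Handlers are async to match the async client's dispatch.
    async def on_message(self, result, **kwargs):
        print(result.channel.alternatives[0].transcript)

    async def on_error(self, error, **kwargs):
        print(f"Error: {error}")

    dg_connection.on(LiveTranscriptionEvents.Transcript, on_message)
    dg_connection.on(LiveTranscriptionEvents.Error, on_error)

    options = LiveOptions(language=language, sample_rate=sample_rate,
                          channels=channels, encoding='linear16')
    await dg_connection.start(options)  # start() is awaitable on the async client
    return dg_connection
```

This is also why the diff changes `dg_socket1.finish()` to `await dg_socket1.finish()` in the cleanup path: the async connection's lifecycle methods are coroutines too.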
12 changes: 6 additions & 6 deletions backend/utils/stt/vad.py
@@ -9,7 +9,7 @@

torch.set_num_threads(1)
torch.hub.set_dir('pretrained_models')
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad', onnx=True)
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils


@@ -19,12 +19,12 @@ class SpeechState(str, Enum):


def is_speech_present(data, vad_iterator, window_size_samples=256):
data_int16 = np.frombuffer(data, dtype=np.int16)
data_float32 = data_int16.astype(np.float32) / 32768.0
# data_int16 = np.frombuffer(data, dtype=np.int16)
# data_float32 = data_int16.astype(np.float32) / 32768.0
has_start, has_end = False, False

for i in range(0, len(data_float32), window_size_samples):
chunk = data_float32[i: i + window_size_samples]
for i in range(0, len(data), window_size_samples):
chunk = data[i: i + window_size_samples]
if len(chunk) < window_size_samples:
break
speech_dict = vad_iterator(chunk, return_seconds=False)
@@ -43,7 +43,7 @@ def is_speech_present(data, vad_iterator, window_size_samples=256):
# return SpeechState.speech_found
# elif has_end:
# return SpeechState.no_speech
vad_iterator.reset_states()
# vad_iterator.reset_states()
return SpeechState.no_speech


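With the ONNX model and the normalization moved to the caller, `is_speech_present` now expects float32 samples in [-1, 1] rather than raw PCM16 bytes. A usage sketch under that contract (the zeroed byte string is a stand-in for decoded Opus/PCM audio):

```python
import numpy as np
from utils.stt.vad import model, VADIterator, is_speech_present, SpeechState

pcm16 = np.zeros(512 * 25, dtype=np.int16).tobytes()  # stand-in for decoded audio
samples = np.frombuffer(pcm16, dtype=np.int16).astype(np.float32) / 32768.0  # normalize once

vad_iterator = VADIterator(model, sampling_rate=16000, threshold=0.7)
state = is_speech_present(list(samples), vad_iterator, window_size_samples=512)
print(state)  # SpeechState.speech_found or SpeechState.no_speech
```

Because the PR comments out the internal `vad_iterator.reset_states()` call, the iterator now keeps state across windows; the caller is responsible for resetting it, as transcribe.py does when speech goes inactive.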