Skip to content

Commit

Permalink
Add new category for memories, secure v3 listen API, deprecated the v… (
Browse files Browse the repository at this point in the history
#1724)

…2 listen API
  • Loading branch information
beastoin authored Jan 23, 2025
2 parents 6ad2272 + 171cac9 commit cd5164a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
6 changes: 6 additions & 0 deletions backend/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from ._client import db, document_id_from_seed


def is_exists_user(uid: str):
user_ref = db.collection('users').document(uid)
if not user_ref.get().exists:
return False
return True

def get_user_store_recording_permission(uid: str):
user_ref = db.collection('users').document(uid)
user_data = user_ref.get().to_dict()
Expand Down
11 changes: 11 additions & 0 deletions backend/models/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ class CategoryEnum(str, Enum):
literature = 'literature'
history = 'history'
architecture = 'architecture'
# Added at 2024-01-23
music = 'music'
weather = 'weather'
news = 'news'
entertainment = 'entertainment'
psychology = 'psychology'
real = 'real'
design = 'design'
family = 'family'
economics = 'economics'
environment = 'environment'
other = 'other'


Expand Down
27 changes: 16 additions & 11 deletions backend/routers/transcribe_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starlette.websockets import WebSocketState

import database.memories as memories_db
import database.users as user_db
from database import redis_db
from database.redis_db import get_cached_user_geolocation
from models.memory import Memory, TranscriptSegment, MemoryStatus, Structured, Geolocation
Expand Down Expand Up @@ -70,7 +71,13 @@ async def _websocket_util(
print('_websocket_util', uid, language, sample_rate, codec, include_speech_profile)

if not uid or len(uid) <= 0:
raise HTTPException(status_code=400, detail="Invalid UID")
await websocket.close(code=1008, reason="Bad uid")
return

# Validate user
if not user_db.is_exists_user(uid):
await websocket.close(code=1008, reason="Bad user")
return

# Not when comes from the phone, and only Friend's with 1.0.4
# if stt_service == STTService.soniox and language not in soniox_valid_languages:
Expand Down Expand Up @@ -439,7 +446,6 @@ async def close(code: int = 1000):

pusher_connect, pusher_close, transcript_send, transcript_consume, audio_bytes_send, audio_bytes_consume = create_pusher_task_handler()


current_memory_id = None

async def stream_transcript_process():
Expand Down Expand Up @@ -493,7 +499,6 @@ async def stream_transcript_process():
memories_db.update_memory_segments(uid, memory.id, [s.dict() for s in memory.transcript_segments])
memories_db.update_memory_finished_at(uid, memory.id, finished_at)


# threading.Thread(target=process_segments, args=(uid, segments)).start() # restore when plugins work
except Exception as e:
print(f'Could not process transcript: error {e}', uid)
Expand Down Expand Up @@ -618,16 +623,16 @@ async def send_heartbeat():
except Exception as e:
print(f"Error closing Pusher: {e}", uid)


@router.websocket("/v2/listen")
async def websocket_endpoint(
websocket: WebSocket, uid: str, language: str = 'en', sample_rate: int = 8000, codec: str = 'pcm8',
channels: int = 1, include_speech_profile: bool = True, stt_service: STTService = STTService.soniox
):
await _websocket_util(websocket, uid, language, sample_rate, codec, channels, include_speech_profile, stt_service)
# @deprecated
# @router.websocket("/v2/listen")
# async def websocket_endpoint_v2(
# websocket: WebSocket, uid: str, language: str = 'en', sample_rate: int = 8000, codec: str = 'pcm8',
# channels: int = 1, include_speech_profile: bool = True, stt_service: STTService = STTService.soniox
# ):
# await _websocket_util(websocket, uid, language, sample_rate, codec, channels, include_speech_profile, stt_service)

@router.websocket("/v3/listen")
async def websocket_endpoint_v3(
async def websocket_endpoint(
websocket: WebSocket, uid: str = Depends(auth.get_current_user_uid), language: str = 'en', sample_rate: int = 8000, codec: str = 'pcm8',
channels: int = 1, include_speech_profile: bool = True, stt_service: STTService = STTService.soniox
):
Expand Down
4 changes: 2 additions & 2 deletions backend/utils/other/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_current_user_uid(authorization: str = Header(None)):
if os.getenv('ADMIN_KEY') in authorization:
if authorization and os.getenv('ADMIN_KEY') in authorization:
return authorization.split(os.getenv('ADMIN_KEY'))[1]

if not authorization:
Expand Down Expand Up @@ -90,4 +90,4 @@ def measure_time(*args, **kw):

def delete_account(uid: str):
auth.delete_user(uid)
return {"message": "User deleted"}
return {"message": "User deleted"}

0 comments on commit cd5164a

Please sign in to comment.