Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocket API #2

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"httpx>=0.27.2",
"ormsgpack>=1.5.0",
"pydantic>=2.9.1",
"httpx-ws>=0.6.2",
]
requires-python = ">=3.10"
readme = "README.md"
Expand Down
11 changes: 10 additions & 1 deletion src/fish_audio_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from .apis import Session
from .exceptions import HttpCodeErr
from .schemas import ASRRequest, TTSRequest, ReferenceAudio
from .websocket import WebSocketSession, AsyncWebSocketSession

__all__ = ["Session", "HttpCodeErr", "ReferenceAudio", "TTSRequest", "ASRRequest"]
__all__ = [
"Session",
"HttpCodeErr",
"ReferenceAudio",
"TTSRequest",
"ASRRequest",
"WebSocketSession",
"AsyncWebSocketSession",
]
6 changes: 6 additions & 0 deletions src/fish_audio_sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ def __init__(self, status: int, message: str):
self.status = status
self.message = message
super().__init__(f"{status} {message}")


class WebSocketErr(Exception):
    """
    Raised when a WebSocket stream ends abnormally.

    Two situations are reported with this exception: the server sends
    ``{"event": "finish", "reason": "error"}``, or the connection drops
    with ``WebSocketDisconnect``.
    """
14 changes: 14 additions & 0 deletions src/fish_audio_sdk/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,17 @@ class PackageEntity(BaseModel):
created_at: str
updated_at: str
finished_at: str


class StartEvent(BaseModel):
    """Opening message of a live TTS stream; carries the TTS request."""

    # Event tag; the only accepted value (and the default) is "start".
    event: Literal["start"] = "start"
    # The TTS request sent along with the start event.
    request: TTSRequest


class TextEvent(BaseModel):
    """Message carrying one piece of text for a live TTS stream."""

    # Event tag; the only accepted value (and the default) is "text".
    event: Literal["text"] = "text"
    # The text fragment carried by this event.
    text: str


class CloseEvent(BaseModel):
    """Closing message of a live TTS stream.

    Note that the class is named ``CloseEvent`` but the event tag it
    serializes is ``"stop"``.
    """

    # Event tag; the only accepted value (and the default) is "stop".
    event: Literal["stop"] = "stop"
147 changes: 147 additions & 0 deletions src/fish_audio_sdk/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncGenerator, AsyncIterable, Generator, Iterable

import httpx
import ormsgpack
from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws

from .exceptions import WebSocketErr

from .schemas import CloseEvent, StartEvent, TTSRequest, TextEvent


class WebSocketSession:
def __init__(
self,
apikey: str,
*,
base_url: str = "https://api.fish.audio",
max_workers: int = 10,
):
self._apikey = apikey
self._base_url = base_url
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def close(self):
self._client.close()

def tts(
self, request: TTSRequest, text_stream: Iterable[str]
) -> Generator[bytes, None, None]:
with connect_ws("/v1/tts/live", client=self._client) as ws:

def sender():
ws.send_bytes(
ormsgpack.packb(
StartEvent(request=request),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
for text in text_stream:
ws.send_bytes(
ormsgpack.packb(
TextEvent(text=text),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
ws.send_bytes(
ormsgpack.packb(
CloseEvent(), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
)
)

sender_future = self._executor.submit(sender)

while True:
try:
message = ws.receive_bytes()
data = ormsgpack.unpackb(message)
match data["event"]:
case "audio":
yield data["audio"]
case "finish" if data["reason"] == "error":
raise WebSocketErr
case "finish" if data["reason"] == "stop":
break
except WebSocketDisconnect:
raise WebSocketErr

sender_future.result()


class AsyncWebSocketSession:
def __init__(
self,
apikey: str,
*,
base_url: str = "https://api.fish.audio",
):
self._apikey = apikey
self._base_url = base_url
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()

async def close(self):
await self._client.aclose()

async def tts(
self, request: TTSRequest, text_stream: AsyncIterable[str]
) -> AsyncGenerator[bytes, None]:
async with aconnect_ws("/v1/tts/live", client=self._client) as ws:

async def sender():
await ws.send_bytes(
ormsgpack.packb(
StartEvent(request=request),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
async for text in text_stream:
await ws.send_bytes(
ormsgpack.packb(
TextEvent(text=text),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
await ws.send_bytes(
ormsgpack.packb(
CloseEvent(), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
)
)

sender_future = asyncio.get_running_loop().create_task(sender())

while True:
try:
message = await ws.receive_bytes()
data = ormsgpack.unpackb(message)
match data["event"]:
case "audio":
yield data["audio"]
case "finish" if data["reason"] == "error":
raise WebSocketErr
case "finish" if data["reason"] == "stop":
break
except WebSocketDisconnect:
raise WebSocketErr

await sender_future
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@

import pytest

from fish_audio_sdk.apis import Session
from fish_audio_sdk import Session, WebSocketSession, AsyncWebSocketSession

APIKEY = os.environ["APIKEY"]


@pytest.fixture
def session():
    """Provide a ``Session`` built from the ``APIKEY`` environment value."""
    api_session = Session(APIKEY)
    return api_session


@pytest.fixture
def sync_websocket():
    """Yield a ``WebSocketSession`` for the test and close it afterwards."""
    # Leaving the ``with`` block calls ``close()`` on the session, so the
    # resources it holds are released once the test finishes.
    with WebSocketSession(APIKEY) as ws_session:
        yield ws_session


@pytest.fixture
def async_websocket():
    """Provide an ``AsyncWebSocketSession`` built from ``APIKEY``."""
    ws_session = AsyncWebSocketSession(APIKEY)
    return ws_session
31 changes: 31 additions & 0 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fish_audio_sdk import TTSRequest, WebSocketSession, AsyncWebSocketSession

story = """
修炼了六千三百七十九年又三月零六天后,天门因她终于洞开。

她凭虚站立在黄山峰顶,因天门洞开而鼓起的飓风不停拍打着她身上的黑袍,在催促她快快登仙而去;黄山间壮阔的云海也随之翻涌,为这一场天地幸事欢呼雀跃。她没有抬头看向那似隐似现、若有若无、形态万千变化的天门,只是呆立在原处自顾自地看向远方。
"""


def test_tts(sync_websocket: WebSocketSession):
    """Stream the story line by line and expect a non-empty audio result."""

    def lines():
        yield from story.split("\n")

    audio = bytearray()
    request = TTSRequest(text="")
    for piece in sync_websocket.tts(request, lines()):
        audio.extend(piece)
    assert len(audio) > 0


async def test_async_tts(async_websocket: AsyncWebSocketSession):
    """Async variant: stream the story line by line and expect audio back."""

    async def lines():
        for row in story.split("\n"):
            yield row

    audio = bytearray()
    request = TTSRequest(text="")
    async for piece in async_websocket.tts(request, lines()):
        audio.extend(piece)
    assert len(audio) > 0