Skip to content

Commit

Permalink
Add APICredit & Package API, fix exception when no body is provided.
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Oct 7, 2024
1 parent 6d6ea0d commit 4f85d74
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/fish_audio_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .apis import Session
from .exceptions import HttpCodeErr
from .schemas import TTSRequest, ASRRequest
from .schemas import ASRRequest, TTSRequest

__all__ = ["Session", "HttpCodeErr", "TTSRequest", "ASRRequest"]
22 changes: 20 additions & 2 deletions src/fish_audio_sdk/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@

import ormsgpack

from .schemas import ASRRequest, ASRResponse, ModelEntity, PaginatedResponse, TTSRequest
from .io import RemoteCall, convert, convert_stream, G, GStream, Request
from .io import G, GStream, RemoteCall, Request, convert, convert_stream
from .schemas import (
APICreditEntity,
ASRRequest,
ASRResponse,
ModelEntity,
PackageEntity,
PaginatedResponse,
TTSRequest,
)


class Session(RemoteCall):
Expand Down Expand Up @@ -140,5 +148,15 @@ def update_model(
files=files,
)

@convert
def get_api_credit(this) -> G[APICreditEntity]:
response = yield Request(method="GET", url="/wallet/self/api-credit")
return APICreditEntity.model_validate(response.json())

@convert
def get_package(this) -> G[PackageEntity]:
response = yield Request(method="GET", url="/wallet/self/package")
return PackageEntity.model_validate(response.json())


filter_none = lambda d: {k: v for k, v in d.items() if v is not None}
15 changes: 11 additions & 4 deletions src/fish_audio_sdk/io.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import dataclasses
import typing
from http.client import responses as http_responses
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
TypeVar,
Concatenate,
Generator,
ParamSpec,
Generic,
Concatenate,
ParamSpec,
TypeVar,
)

import httpx
Expand Down Expand Up @@ -63,8 +64,14 @@ def _try_raise_http_exception(resp: httpx.Response) -> None:
if not resp.is_success:
try:
raise HttpCodeErr(**resp.json())
except httpx.ResponseNotRead:
raise HttpCodeErr(
status=resp.status_code, message=http_responses[resp.status_code]
)
except TypeError:
raise HttpCodeErr(status=resp.status_code, message=resp.json()["detail"])
raise HttpCodeErr(
status=resp.status_code, message=resp.json()["detail"]
)


P = ParamSpec("P")
Expand Down
22 changes: 21 additions & 1 deletion src/fish_audio_sdk/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Annotated, Literal, Generic, TypeVar
import decimal
from typing import Annotated, Generic, Literal, TypeVar

from pydantic import BaseModel, Field, conint

Expand Down Expand Up @@ -84,3 +85,22 @@ class ModelEntity(BaseModel):
marked: bool = False

author: AuthorEntity


class APICreditEntity(BaseModel):
_id: str
user_id: str
credit: decimal.Decimal
created_at: str
updated_at: str


class PackageEntity(BaseModel):
_id: str
user_id: str
type: str
total: int
balance: int
created_at: str
updated_at: str
finished_at: str
31 changes: 27 additions & 4 deletions tests/test_apis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from fish_audio_sdk import Session, HttpCodeErr, TTSRequest, ASRRequest
import pytest

from fish_audio_sdk import ASRRequest, HttpCodeErr, Session, TTSRequest
from fish_audio_sdk.schemas import APICreditEntity, PackageEntity


def test_tts(session: Session):
Expand Down Expand Up @@ -47,7 +50,27 @@ def test_get_model(session: Session):


def test_get_model_not_found(session: Session):
try:
with pytest.raises(HttpCodeErr) as exc_info:
session.get_model(model_id="123")
except HttpCodeErr as e:
assert e.status == 404
assert exc_info.value.status == 404


def test_invalid_token(session: Session):
session._apikey = "invalid"
session.init_async_client()
session.init_sync_client()

with pytest.raises(HttpCodeErr) as exc_info:
test_tts(session)

assert exc_info.value.status in [401, 402]


def test_get_api_credit(session: Session):
res = session.get_api_credit()
assert isinstance(res, APICreditEntity)


def test_get_package(session: Session):
res = session.get_package()
assert isinstance(res, PackageEntity)

0 comments on commit 4f85d74

Please sign in to comment.