Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add APICredit & Package API, fix exception when no body is provided. #1

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/fish_audio_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .apis import Session
from .exceptions import HttpCodeErr
from .schemas import TTSRequest, ASRRequest
from .schemas import ASRRequest, TTSRequest

__all__ = ["Session", "HttpCodeErr", "TTSRequest", "ASRRequest"]
22 changes: 20 additions & 2 deletions src/fish_audio_sdk/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@

import ormsgpack

from .schemas import ASRRequest, ASRResponse, ModelEntity, PaginatedResponse, TTSRequest
from .io import RemoteCall, convert, convert_stream, G, GStream, Request
from .io import G, GStream, RemoteCall, Request, convert, convert_stream
from .schemas import (
APICreditEntity,
ASRRequest,
ASRResponse,
ModelEntity,
PackageEntity,
PaginatedResponse,
TTSRequest,
)


class Session(RemoteCall):
Expand Down Expand Up @@ -140,5 +148,15 @@ def update_model(
files=files,
)

@convert
def get_api_credit(this) -> G[APICreditEntity]:
response = yield Request(method="GET", url="/wallet/self/api-credit")
return APICreditEntity.model_validate(response.json())

@convert
def get_package(this) -> G[PackageEntity]:
response = yield Request(method="GET", url="/wallet/self/package")
return PackageEntity.model_validate(response.json())


filter_none = lambda d: {k: v for k, v in d.items() if v is not None}
15 changes: 11 additions & 4 deletions src/fish_audio_sdk/io.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import dataclasses
import typing
from http.client import responses as http_responses
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
TypeVar,
Concatenate,
Generator,
ParamSpec,
Generic,
Concatenate,
ParamSpec,
TypeVar,
)

import httpx
Expand Down Expand Up @@ -63,8 +64,14 @@ def _try_raise_http_exception(resp: httpx.Response) -> None:
if not resp.is_success:
try:
raise HttpCodeErr(**resp.json())
except httpx.ResponseNotRead:
raise HttpCodeErr(
status=resp.status_code, message=http_responses[resp.status_code]
)
except TypeError:
raise HttpCodeErr(status=resp.status_code, message=resp.json()["detail"])
raise HttpCodeErr(
status=resp.status_code, message=resp.json()["detail"]
)


P = ParamSpec("P")
Expand Down
22 changes: 21 additions & 1 deletion src/fish_audio_sdk/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Annotated, Literal, Generic, TypeVar
import decimal
from typing import Annotated, Generic, Literal, TypeVar

from pydantic import BaseModel, Field, conint

Expand Down Expand Up @@ -84,3 +85,22 @@ class ModelEntity(BaseModel):
marked: bool = False

author: AuthorEntity


class APICreditEntity(BaseModel):
_id: str
user_id: str
credit: decimal.Decimal
created_at: str
updated_at: str


class PackageEntity(BaseModel):
_id: str
user_id: str
type: str
total: int
balance: int
created_at: str
updated_at: str
finished_at: str
31 changes: 27 additions & 4 deletions tests/test_apis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from fish_audio_sdk import Session, HttpCodeErr, TTSRequest, ASRRequest
import pytest

from fish_audio_sdk import ASRRequest, HttpCodeErr, Session, TTSRequest
from fish_audio_sdk.schemas import APICreditEntity, PackageEntity


def test_tts(session: Session):
Expand Down Expand Up @@ -47,7 +50,27 @@ def test_get_model(session: Session):


def test_get_model_not_found(session: Session):
try:
with pytest.raises(HttpCodeErr) as exc_info:
session.get_model(model_id="123")
except HttpCodeErr as e:
assert e.status == 404
assert exc_info.value.status == 404


def test_invalid_token(session: Session):
session._apikey = "invalid"
session.init_async_client()
session.init_sync_client()

with pytest.raises(HttpCodeErr) as exc_info:
test_tts(session)

assert exc_info.value.status in [401, 402]


def test_get_api_credit(session: Session):
res = session.get_api_credit()
assert isinstance(res, APICreditEntity)


def test_get_package(session: Session):
res = session.get_package()
assert isinstance(res, PackageEntity)