refactor assistant streaming and create OpenAI compliant base class #425
```diff
@@ -1,12 +1,11 @@
-import json
 from typing import AsyncIterator, cast
 
 from ragna.core import PackageRequirement, RagnaException, Requirement, Source
 
-from ._api import ApiAssistant
+from ._http_api import HttpApiAssistant
 
 
-class AnthropicApiAssistant(ApiAssistant):
+class AnthropicAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
     _MODEL: str
 
```

Review comment: Drive-by rename to align it with the other provider base classes.

```diff
@@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
             + "</documents>"
         )
 
-    async def _call_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        import httpx_sse
-
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
-        async with httpx_sse.aconnect_sse(
-            self._client,
+        async for data in self._stream_sse(
             "POST",
             "https://api.anthropic.com/v1/messages",
             headers={
@@ -61,23 +57,19 @@ async def _call_api(
                 "temperature": 0.0,
                 "stream": True,
             },
-        ) as event_source:
-            await self._assert_api_call_is_success(event_source.response)
-
-            async for sse in event_source.aiter_sse():
-                data = json.loads(sse.data)
-                # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
-                if "error" in data:
-                    raise RagnaException(data["error"].pop("message"), **data["error"])
-                elif data["type"] == "message_stop":
-                    break
-                elif data["type"] != "content_block_delta":
-                    continue
-
-                yield cast(str, data["delta"].pop("text"))
+        ):
+            # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
+            if "error" in data:
+                raise RagnaException(data["error"].pop("message"), **data["error"])
+            elif data["type"] == "message_stop":
+                break
+            elif data["type"] != "content_block_delta":
+                continue
+
+            yield cast(str, data["delta"].pop("text"))
 
 
-class ClaudeOpus(AnthropicApiAssistant):
+class ClaudeOpus(AnthropicAssistant):
     """[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview)
 
     !!! info "Required environment variables"
@@ -92,7 +84,7 @@ class ClaudeOpus(AnthropicApiAssistant):
     _MODEL = "claude-3-opus-20240229"
 
 
-class ClaudeSonnet(AnthropicApiAssistant):
+class ClaudeSonnet(AnthropicAssistant):
     """[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview)
 
     !!! info "Required environment variables"
@@ -107,7 +99,7 @@ class ClaudeSonnet(AnthropicApiAssistant):
     _MODEL = "claude-3-sonnet-20240229"
 
 
-class ClaudeHaiku(AnthropicApiAssistant):
+class ClaudeHaiku(AnthropicAssistant):
     """[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview)
 
     !!! info "Required environment variables"
```
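For context, a minimal usage sketch (not part of the diff) of the streaming `answer()` shown above. The import path `ragna.assistants.ClaudeHaiku`, the prompt, and the empty source list are assumptions made for illustration; the constructor also expects `ANTHROPIC_API_KEY` to be set in the environment.

```python
# Illustrative only: consuming the streaming answer() from an async context.
# The import path, prompt, and empty sources list are assumed for this sketch.
import asyncio

from ragna.assistants import ClaudeHaiku


async def main() -> None:
    assistant = ClaudeHaiku()  # reads ANTHROPIC_API_KEY from the environment
    async for chunk in assistant.answer("Hello!", sources=[], max_new_tokens=64):
        # Each chunk is a piece of the generated text, streamed as it arrives.
        print(chunk, end="", flush=True)


asyncio.run(main())
```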
This file was deleted.

Review comment: This file was renamed to `_http_api.py`.
New file (`@@ -0,0 +1,80 @@`):

Review comment (on `_API_KEY_ENV_VAR`): The API key is now optional. See #375 for a discussion.

```python
import contextlib
import json
import os
from typing import Any, AsyncIterator, Optional

import httpx
from httpx import Response

import ragna
from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement


class HttpApiAssistant(Assistant):
    _API_KEY_ENV_VAR: Optional[str]

    @classmethod
    def requirements(cls) -> list[Requirement]:
        requirements: list[Requirement] = (
            [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
            if cls._API_KEY_ENV_VAR is not None
            else []
        )
        requirements.extend(cls._extra_requirements())
        return requirements

    @classmethod
    def _extra_requirements(cls) -> list[Requirement]:
        return []

    def __init__(self) -> None:
        self._client = httpx.AsyncClient(
            headers={"User-Agent": f"{ragna.__version__}/{self}"},
            timeout=60,
        )
        self._api_key: Optional[str] = (
            os.environ[self._API_KEY_ENV_VAR]
            if self._API_KEY_ENV_VAR is not None
            else None
        )

    async def _assert_api_call_is_success(self, response: Response) -> None:
        if response.is_success:
            return

        content = await response.aread()
        with contextlib.suppress(Exception):
            content = json.loads(content)

        raise RagnaException(
            "API call failed",
            request_method=response.request.method,
            request_url=str(response.request.url),
            response_status_code=response.status_code,
            response_content=content,
        )

    async def _stream_sse(
        self,
        method: str,
        url: str,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
        import httpx_sse

        async with httpx_sse.aconnect_sse(
            self._client, method, url, **kwargs
        ) as event_source:
            await self._assert_api_call_is_success(event_source.response)

            async for sse in event_source.aiter_sse():
                yield json.loads(sse.data)

    async def _stream_jsonl(
        self, method: str, url: str, **kwargs: Any
    ) -> AsyncIterator[dict[str, Any]]:
        async with self._client.stream(method, url, **kwargs) as response:
            await self._assert_api_call_is_success(response)

            async for chunk in response.aiter_lines():
                yield json.loads(chunk)
```
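To illustrate the new base class, here is a hypothetical subclass sketch (not part of the PR). The `EXAMPLE_API_KEY` variable, the endpoint, the request payload, and the `text` response field are all made up, and the class is written as if it lived next to `_http_api.py`.

```python
# Illustrative only: a made-up assistant built on the new HttpApiAssistant base.
# Endpoint, env var, payload, and response shape are assumptions for this sketch.
from typing import AsyncIterator

from ragna.core import Source

from ._http_api import HttpApiAssistant


class ExampleJsonlAssistant(HttpApiAssistant):
    # requirements() automatically picks this up as an EnvVarRequirement
    _API_KEY_ENV_VAR = "EXAMPLE_API_KEY"

    async def answer(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
        # Stream newline-delimited JSON from a hypothetical local endpoint.
        async for data in self._stream_jsonl(
            "POST",
            "http://localhost:9000/generate",  # assumed endpoint
            headers={"Authorization": f"Bearer {self._api_key}"},
            json={"prompt": prompt, "max_tokens": max_new_tokens, "stream": True},
        ):
            yield data.get("text", "")  # assumed field in each JSON line
```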
Review comment: This is a demonstration of how easy it is, after this PR, to add new OpenAI-compliant assistants.
New file (`@@ -0,0 +1,14 @@`):

```python
import os

from ._openai import OpenaiCompliantHttpApiAssistant


class LlamafileAssistant(OpenaiCompliantHttpApiAssistant):
    _API_KEY_ENV_VAR = None
    _STREAMING_METHOD = "sse"
    _MODEL = None

    @property
    def _url(self) -> str:
        base_url = os.environ.get("RAGNA_LLAMAFILE_BASE_URL", "http://localhost:8080")
        return f"{base_url}/v1/chat/completions"
```

Review comment: @nenb, is port 8080 the default?

Reply: This seems to be the case.
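As a further illustration of the comment above, another OpenAI-compliant server could be wired up with the same pattern. This is a sketch only, not part of the PR; the `RAGNA_LOCAL_OPENAI_BASE_URL` variable, port 8000, and the class name are made up.

```python
# Illustrative only: a hypothetical assistant for another OpenAI-compliant server,
# following the same pattern as LlamafileAssistant above.
import os

from ._openai import OpenaiCompliantHttpApiAssistant


class LocalOpenaiServerAssistant(OpenaiCompliantHttpApiAssistant):
    _API_KEY_ENV_VAR = None  # local server, no API key required
    _STREAMING_METHOD = "sse"
    _MODEL = None  # let the server pick whatever model it is serving

    @property
    def _url(self) -> str:
        # Assumed env var and default port for this sketch.
        base_url = os.environ.get("RAGNA_LOCAL_OPENAI_BASE_URL", "http://localhost:8000")
        return f"{base_url}/v1/chat/completions"
```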
Review comment: The `_call_api` abstraction behind `answer` was just a remnant of an old implementation that I forgot to clean up earlier (see ragna/ragna/assistants/_api.py, lines 32 to 38 at 84cf4f6). This PR removes it, and all subclasses now simply implement `answer` directly.