Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor assistant streaming and create OpenAI compliant base class #425

Merged
merged 5 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/examples/gallery_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
# - [OpenAI](https://openai.com/)
# - [ragna.assistants.Gpt35Turbo16k][]
# - [ragna.assistants.Gpt4][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]

from ragna import assistants

Expand Down
4 changes: 3 additions & 1 deletion docs/tutorials/gallery_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@
# - [ragna.assistants.Gpt4][]
# - [AI21 Labs](https://www.ai21.com/)
# - [ragna.assistants.Jurassic2Ultra][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]
#
# !!! note
#
# To use any of builtin assistants, you need to
# To use some of the builtin assistants, you need to
# [procure API keys](../../references/faq.md#where-do-i-get-api-keys-for-the-builtin-assistants)
# first and set the corresponding environment variables.

Expand Down
2 changes: 2 additions & 0 deletions ragna/assistants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"Gpt35Turbo16k",
"Gpt4",
"Jurassic2Ultra",
"LlamafileAssistant",
"RagnaDemoAssistant",
]

Expand All @@ -17,6 +18,7 @@
from ._cohere import Command, CommandLight
from ._demo import RagnaDemoAssistant
from ._google import GeminiPro, GeminiUltra
from ._llamafile import LlamafileAssistant
from ._openai import Gpt4, Gpt35Turbo16k

# isort: split
Expand Down
8 changes: 4 additions & 4 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ragna.core import Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


class Ai21LabsAssistant(ApiAssistant):
class Ai21LabsAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "AI21_API_KEY"
_MODEL_TYPE: str

Expand All @@ -21,8 +21,8 @@ def _make_system_content(self, sources: list[Source]) -> str:
)
return instruction + "\n\n".join(source.content for source in sources)

async def _call_api(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The def _call_api abstraction for def answer was just a remnant of an old implementation that I forgot to clean up earlier:

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
async for chunk in self._call_api(
prompt, sources, max_new_tokens=max_new_tokens
):
yield chunk

This PR removes it and all subclasses simply implement def answer directly.

self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
Expand Down
42 changes: 17 additions & 25 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
from typing import AsyncIterator, cast

from ragna.core import PackageRequirement, RagnaException, Requirement, Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


class AnthropicApiAssistant(ApiAssistant):
class AnthropicAssistant(HttpApiAssistant):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Driveby rename to align it with other provider base classes.

_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
_MODEL: str

Expand Down Expand Up @@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
+ "</documents>"
)

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
import httpx_sse

# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
async with httpx_sse.aconnect_sse(
self._client,
async for data in self._stream_sse(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
Expand All @@ -61,23 +57,19 @@ async def _call_api(
"temperature": 0.0,
"stream": True,
},
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue
):
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))
yield cast(str, data["delta"].pop("text"))


class ClaudeOpus(AnthropicApiAssistant):
class ClaudeOpus(AnthropicAssistant):
"""[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview)

!!! info "Required environment variables"
Expand All @@ -92,7 +84,7 @@ class ClaudeOpus(AnthropicApiAssistant):
_MODEL = "claude-3-opus-20240229"


class ClaudeSonnet(AnthropicApiAssistant):
class ClaudeSonnet(AnthropicAssistant):
"""[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview)

!!! info "Required environment variables"
Expand All @@ -107,7 +99,7 @@ class ClaudeSonnet(AnthropicApiAssistant):
_MODEL = "claude-3-sonnet-20240229"


class ClaudeHaiku(AnthropicApiAssistant):
class ClaudeHaiku(AnthropicAssistant):
"""[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview)

!!! info "Required environment variables"
Expand Down
59 changes: 0 additions & 59 deletions ragna/assistants/_api.py

This file was deleted.

33 changes: 14 additions & 19 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


class CohereApiAssistant(ApiAssistant):
class CohereAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "COHERE_API_KEY"
_MODEL: str

Expand All @@ -24,13 +23,13 @@ def _make_preamble(self) -> str:
def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
return [{"title": source.id, "snippet": source.content} for source in sources]

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
async with self._client.stream(
async for event in self._stream_jsonl(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
Expand All @@ -47,21 +46,17 @@ async def _call_api(
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
},
) as response:
await self._assert_api_call_is_success(response)
):
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

async for chunk in response.aiter_lines():
event = json.loads(chunk)
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break
raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])


class Command(CohereApiAssistant):
class Command(CohereAssistant):
"""
[Cohere Command](https://docs.cohere.com/docs/models#command)

Expand All @@ -73,7 +68,7 @@ class Command(CohereApiAssistant):
_MODEL = "command"


class CommandLight(CohereApiAssistant):
class CommandLight(CohereAssistant):
"""
[Cohere Command-Light](https://docs.cohere.com/docs/models#command)

Expand Down
12 changes: 6 additions & 6 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


# ijson does not support reading from an (async) iterator, but only from file-like
Expand All @@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes:
return await anext(self._ait, b"") # type: ignore[call-arg]


class GoogleApiAssistant(ApiAssistant):
class GoogleAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_MODEL: str

Expand All @@ -48,8 +48,8 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
]
)

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
import ijson

Expand Down Expand Up @@ -88,7 +88,7 @@ async def _call_api(
yield chunk


class GeminiPro(GoogleApiAssistant):
class GeminiPro(GoogleAssistant):
"""[Google Gemini Pro](https://ai.google.dev/models/gemini)

!!! info "Required environment variables"
Expand All @@ -103,7 +103,7 @@ class GeminiPro(GoogleApiAssistant):
_MODEL = "gemini-pro"


class GeminiUltra(GoogleApiAssistant):
class GeminiUltra(GoogleAssistant):
"""[Google Gemini Ultra](https://ai.google.dev/models/gemini)

!!! info "Required environment variables"
Expand Down
Loading