Skip to content

Commit

Permalink
refactor assistant streaming and create OpenAI compliant base class (#…
Browse files: browse the repository at this point in the history
  • Loading branch information
pmeier authored May 28, 2024
1 parent 84cf4f6 commit a45bd90
Show file tree
Hide file tree
Showing 12 changed files with 241 additions and 173 deletions.
2 changes: 2 additions & 0 deletions docs/examples/gallery_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
# - [OpenAI](https://openai.com/)
# - [ragna.assistants.Gpt35Turbo16k][]
# - [ragna.assistants.Gpt4][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]

from ragna import assistants

Expand Down
4 changes: 3 additions & 1 deletion docs/tutorials/gallery_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@
# - [ragna.assistants.Gpt4][]
# - [AI21 Labs](https://www.ai21.com/)
# - [ragna.assistants.Jurassic2Ultra][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]
#
# !!! note
#
# To use any of builtin assistants, you need to
# To use some of the builtin assistants, you need to
# [procure API keys](../../references/faq.md#where-do-i-get-api-keys-for-the-builtin-assistants)
# first and set the corresponding environment variables.

Expand Down
2 changes: 2 additions & 0 deletions ragna/assistants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"Gpt35Turbo16k",
"Gpt4",
"Jurassic2Ultra",
"LlamafileAssistant",
"RagnaDemoAssistant",
]

Expand All @@ -17,6 +18,7 @@
from ._cohere import Command, CommandLight
from ._demo import RagnaDemoAssistant
from ._google import GeminiPro, GeminiUltra
from ._llamafile import LlamafileAssistant
from ._openai import Gpt4, Gpt35Turbo16k

# isort: split
Expand Down
8 changes: 4 additions & 4 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ragna.core import Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


class Ai21LabsAssistant(ApiAssistant):
class Ai21LabsAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "AI21_API_KEY"
_MODEL_TYPE: str

Expand All @@ -21,8 +21,8 @@ def _make_system_content(self, sources: list[Source]) -> str:
)
return instruction + "\n\n".join(source.content for source in sources)

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
Expand Down
42 changes: 17 additions & 25 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
from typing import AsyncIterator, cast

from ragna.core import PackageRequirement, RagnaException, Requirement, Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


class AnthropicApiAssistant(ApiAssistant):
class AnthropicAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
_MODEL: str

Expand Down Expand Up @@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
+ "</documents>"
)

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
import httpx_sse

# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
async with httpx_sse.aconnect_sse(
self._client,
async for data in self._stream_sse(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
Expand All @@ -61,23 +57,19 @@ async def _call_api(
"temperature": 0.0,
"stream": True,
},
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue
):
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))
yield cast(str, data["delta"].pop("text"))


class ClaudeOpus(AnthropicApiAssistant):
class ClaudeOpus(AnthropicAssistant):
"""[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview)
!!! info "Required environment variables"
Expand All @@ -92,7 +84,7 @@ class ClaudeOpus(AnthropicApiAssistant):
_MODEL = "claude-3-opus-20240229"


class ClaudeSonnet(AnthropicApiAssistant):
class ClaudeSonnet(AnthropicAssistant):
"""[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview)
!!! info "Required environment variables"
Expand All @@ -107,7 +99,7 @@ class ClaudeSonnet(AnthropicApiAssistant):
_MODEL = "claude-3-sonnet-20240229"


class ClaudeHaiku(AnthropicApiAssistant):
class ClaudeHaiku(AnthropicAssistant):
"""[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview)
!!! info "Required environment variables"
Expand Down
59 changes: 0 additions & 59 deletions ragna/assistants/_api.py

This file was deleted.

33 changes: 14 additions & 19 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


class CohereApiAssistant(ApiAssistant):
class CohereAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "COHERE_API_KEY"
_MODEL: str

Expand All @@ -24,13 +23,13 @@ def _make_preamble(self) -> str:
def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
return [{"title": source.id, "snippet": source.content} for source in sources]

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
async with self._client.stream(
async for event in self._stream_jsonl(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
Expand All @@ -47,21 +46,17 @@ async def _call_api(
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
},
) as response:
await self._assert_api_call_is_success(response)
):
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

async for chunk in response.aiter_lines():
event = json.loads(chunk)
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break
raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])


class Command(CohereApiAssistant):
class Command(CohereAssistant):
"""
[Cohere Command](https://docs.cohere.com/docs/models#command)
Expand All @@ -73,7 +68,7 @@ class Command(CohereApiAssistant):
_MODEL = "command"


class CommandLight(CohereApiAssistant):
class CommandLight(CohereAssistant):
"""
[Cohere Command-Light](https://docs.cohere.com/docs/models#command)
Expand Down
12 changes: 6 additions & 6 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant
from ._http_api import HttpApiAssistant


# ijson does not support reading from an (async) iterator, but only from file-like
Expand All @@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes:
return await anext(self._ait, b"") # type: ignore[call-arg]


class GoogleApiAssistant(ApiAssistant):
class GoogleAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_MODEL: str

Expand All @@ -48,8 +48,8 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
]
)

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
import ijson

Expand Down Expand Up @@ -88,7 +88,7 @@ async def _call_api(
yield chunk


class GeminiPro(GoogleApiAssistant):
class GeminiPro(GoogleAssistant):
"""[Google Gemini Pro](https://ai.google.dev/models/gemini)
!!! info "Required environment variables"
Expand All @@ -103,7 +103,7 @@ class GeminiPro(GoogleApiAssistant):
_MODEL = "gemini-pro"


class GeminiUltra(GoogleApiAssistant):
class GeminiUltra(GoogleAssistant):
"""[Google Gemini Ultra](https://ai.google.dev/models/gemini)
!!! info "Required environment variables"
Expand Down
Loading

0 comments on commit a45bd90

Please sign in to comment.