diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py
index 6ffda30a..84d92d08 100644
--- a/docs/examples/gallery_streaming.py
+++ b/docs/examples/gallery_streaming.py
@@ -29,6 +29,8 @@
 # - [OpenAI](https://openai.com/)
 #     - [ragna.assistants.Gpt35Turbo16k][]
 #     - [ragna.assistants.Gpt4][]
+# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
+#     - [ragna.assistants.LlamafileAssistant][]
 
 from ragna import assistants
 
diff --git a/docs/tutorials/gallery_python_api.py b/docs/tutorials/gallery_python_api.py
index d1919bd3..a7d0ef63 100644
--- a/docs/tutorials/gallery_python_api.py
+++ b/docs/tutorials/gallery_python_api.py
@@ -85,10 +85,12 @@
 #     - [ragna.assistants.Gpt4][]
 # - [AI21 Labs](https://www.ai21.com/)
 #     - [ragna.assistants.Jurassic2Ultra][]
+# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
+#     - [ragna.assistants.LlamafileAssistant][]
 #
 # !!! note
 #
-#     To use any of builtin assistants, you need to
+#     To use some of the builtin assistants, you need to
 #     [procure API keys](../../references/faq.md#where-do-i-get-api-keys-for-the-builtin-assistants)
 #     first and set the corresponding environment variables.
 
diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py
index 823d87ac..d583e7a0 100644
--- a/ragna/assistants/__init__.py
+++ b/ragna/assistants/__init__.py
@@ -9,6 +9,7 @@
     "Gpt35Turbo16k",
     "Gpt4",
     "Jurassic2Ultra",
+    "LlamafileAssistant",
     "RagnaDemoAssistant",
 ]
 
@@ -17,6 +18,7 @@
 from ._cohere import Command, CommandLight
 from ._demo import RagnaDemoAssistant
 from ._google import GeminiPro, GeminiUltra
+from ._llamafile import LlamafileAssistant
 from ._openai import Gpt4, Gpt35Turbo16k  # isort: split
 
diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py
index 19cfd59b..1c61a213 100644
--- a/ragna/assistants/_ai21labs.py
+++ b/ragna/assistants/_ai21labs.py
@@ -2,10 +2,10 @@
 
 from ragna.core import Source
 
-from ._api import ApiAssistant
+from ._http_api import HttpApiAssistant
 
 
-class Ai21LabsAssistant(ApiAssistant):
+class Ai21LabsAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "AI21_API_KEY"
     _MODEL_TYPE: str
 
@@ -21,8 +21,8 @@ def _make_system_content(self, sources: list[Source]) -> str:
         )
         return instruction + "\n\n".join(source.content for source in sources)
 
-    async def _call_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py
index fa8922fe..37f132b5 100644
--- a/ragna/assistants/_anthropic.py
+++ b/ragna/assistants/_anthropic.py
@@ -1,12 +1,11 @@
-import json
 from typing import AsyncIterator, cast
 
 from ragna.core import PackageRequirement, RagnaException, Requirement, Source
 
-from ._api import ApiAssistant
+from ._http_api import HttpApiAssistant
 
 
-class AnthropicApiAssistant(ApiAssistant):
+class AnthropicAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
     _MODEL: str
 
@@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
             + ""
         )
 
-    async def _call_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        import httpx_sse
-
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
-        async with httpx_sse.aconnect_sse(
-            self._client,
+        async for data in self._stream_sse(
             "POST",
             "https://api.anthropic.com/v1/messages",
             headers={
@@ -61,23 +57,19 @@
                 "temperature": 0.0,
                 "stream": True,
             },
-        ) as event_source:
-            await self._assert_api_call_is_success(event_source.response)
-
-            async for sse in event_source.aiter_sse():
-                data = json.loads(sse.data)
-                # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
-                if "error" in data:
-                    raise RagnaException(data["error"].pop("message"), **data["error"])
-                elif data["type"] == "message_stop":
-                    break
-                elif data["type"] != "content_block_delta":
-                    continue
+        ):
+            # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
+            if "error" in data:
+                raise RagnaException(data["error"].pop("message"), **data["error"])
+            elif data["type"] == "message_stop":
+                break
+            elif data["type"] != "content_block_delta":
+                continue
 
-                yield cast(str, data["delta"].pop("text"))
+            yield cast(str, data["delta"].pop("text"))
 
 
-class ClaudeOpus(AnthropicApiAssistant):
+class ClaudeOpus(AnthropicAssistant):
     """[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview)
 
     !!! info "Required environment variables"
@@ -92,7 +84,7 @@
     _MODEL = "claude-3-opus-20240229"
 
 
-class ClaudeSonnet(AnthropicApiAssistant):
+class ClaudeSonnet(AnthropicAssistant):
     """[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview)
 
     !!! info "Required environment variables"
@@ -107,7 +99,7 @@
     _MODEL = "claude-3-sonnet-20240229"
 
 
-class ClaudeHaiku(AnthropicApiAssistant):
+class ClaudeHaiku(AnthropicAssistant):
     """[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview)
 
     !!! info "Required environment variables"
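Folding `_call_api` into `answer` makes each assistant's `answer` the async generator itself, with `max_new_tokens` defaulting to 256. A minimal consumption sketch, assuming `ANTHROPIC_API_KEY` is set (the prompt and the empty source list are illustrative only, not part of this diff):

```python
import asyncio

from ragna.assistants import ClaudeHaiku


async def main() -> None:
    assistant = ClaudeHaiku()  # reads ANTHROPIC_API_KEY from the environment
    # answer() now streams directly; no _call_api indirection remains.
    async for chunk in assistant.answer("Summarize the sources.", sources=[]):
        print(chunk, end="", flush=True)


asyncio.run(main())
```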
diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py
deleted file mode 100644
index 446d9545..00000000
--- a/ragna/assistants/_api.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import abc
-import contextlib
-import json
-import os
-from typing import AsyncIterator
-
-import httpx
-from httpx import Response
-
-import ragna
-from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement, Source
-
-
-class ApiAssistant(Assistant):
-    _API_KEY_ENV_VAR: str
-
-    @classmethod
-    def requirements(cls) -> list[Requirement]:
-        return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()]
-
-    @classmethod
-    def _extra_requirements(cls) -> list[Requirement]:
-        return []
-
-    def __init__(self) -> None:
-        self._client = httpx.AsyncClient(
-            headers={"User-Agent": f"{ragna.__version__}/{self}"},
-            timeout=60,
-        )
-        self._api_key = os.environ[self._API_KEY_ENV_VAR]
-
-    async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
-        async for chunk in self._call_api(
-            prompt, sources, max_new_tokens=max_new_tokens
-        ):
-            yield chunk
-
-    @abc.abstractmethod
-    def _call_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
-    ) -> AsyncIterator[str]: ...
-
-    async def _assert_api_call_is_success(self, response: Response) -> None:
-        if response.is_success:
-            return
-
-        content = await response.aread()
-        with contextlib.suppress(Exception):
-            content = json.loads(content)
-
-        raise RagnaException(
-            "API call failed",
-            request_method=response.request.method,
-            request_url=str(response.request.url),
-            response_status_code=response.status_code,
-            response_content=content,
-        )
diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py
index a93a264f..b47737f8 100644
--- a/ragna/assistants/_cohere.py
+++ b/ragna/assistants/_cohere.py
@@ -1,12 +1,11 @@
-import json
 from typing import AsyncIterator, cast
 
 from ragna.core import RagnaException, Source
 
-from ._api import ApiAssistant
+from ._http_api import HttpApiAssistant
 
 
-class CohereApiAssistant(ApiAssistant):
+class CohereAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "COHERE_API_KEY"
     _MODEL: str
 
@@ -24,13 +23,13 @@ def _make_preamble(self) -> str:
     def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
         return [{"title": source.id, "snippet": source.content} for source in sources]
 
-    async def _call_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
-        async with self._client.stream(
+        async for event in self._stream_jsonl(
             "POST",
             "https://api.cohere.ai/v1/chat",
             headers={
@@ -47,21 +46,17 @@ async def _call_api(
                 "max_tokens": max_new_tokens,
                 "documents": self._make_source_documents(sources),
             },
-        ) as response:
-            await self._assert_api_call_is_success(response)
+        ):
+            if event["event_type"] == "stream-end":
+                if event["finish_reason"] == "COMPLETE":
+                    break
 
-            async for chunk in response.aiter_lines():
-                event = json.loads(chunk)
-                if event["event_type"] == "stream-end":
-                    if event["event_type"] == "COMPLETE":
-                        break
+                raise RagnaException(event["error_message"])
+            if "text" in event:
+                yield cast(str, event["text"])
 
-                    raise RagnaException(event["error_message"])
-                if "text" in event:
-                    yield cast(str, event["text"])
-
 
-class Command(CohereApiAssistant):
+class Command(CohereAssistant):
     """
     [Cohere Command](https://docs.cohere.com/docs/models#command)
 
@@ -73,7 +68,7 @@ class Command(CohereApiAssistant):
     _MODEL = "command"
 
 
-class CommandLight(CohereApiAssistant):
+class CommandLight(CohereAssistant):
     """
     [Cohere Command-Light](https://docs.cohere.com/docs/models#command)
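For orientation, the newline-delimited JSON events the Cohere loop above consumes look roughly like this. The field names follow the Cohere streaming docs linked in the code (note that a `stream-end` event signals completion via `finish_reason`, not `event_type`); the values are invented for illustration:

```python
# Illustrative Cohere stream events (values are made up; see the linked docs).
text_event = {"event_type": "text-generation", "text": "Ragna is a RAG framework."}
end_ok = {"event_type": "stream-end", "finish_reason": "COMPLETE"}
end_err = {
    "event_type": "stream-end",
    "finish_reason": "ERROR",
    "error_message": "something went wrong",
}
```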
diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py
index afbb829a..8e1caf1e 100644
--- a/ragna/assistants/_google.py
+++ b/ragna/assistants/_google.py
@@ -3,7 +3,7 @@
 from ragna._compat import anext
 from ragna.core import PackageRequirement, Requirement, Source
 
-from ._api import ApiAssistant
+from ._http_api import HttpApiAssistant
 
 
 # ijson does not support reading from an (async) iterator, but only from file-like
@@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes:
         return await anext(self._ait, b"")  # type: ignore[call-arg]
 
 
-class GoogleApiAssistant(ApiAssistant):
+class GoogleAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "GOOGLE_API_KEY"
     _MODEL: str
 
@@ -48,8 +48,8 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
             ]
         )
 
-    async def _call_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+    async def answer(
+        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         import ijson
 
@@ -88,7 +88,7 @@ async def _call_api(
         yield chunk
 
 
-class GeminiPro(GoogleApiAssistant):
+class GeminiPro(GoogleAssistant):
     """[Google Gemini Pro](https://ai.google.dev/models/gemini)
 
     !!! info "Required environment variables"
@@ -103,7 +103,7 @@
     _MODEL = "gemini-pro"
 
 
-class GeminiUltra(GoogleApiAssistant):
+class GeminiUltra(GoogleAssistant):
     """[Google Gemini Ultra](https://ai.google.dev/models/gemini)
 
     !!! info "Required environment variables"
diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py
new file mode 100644
index 00000000..1151a62a
--- /dev/null
+++ b/ragna/assistants/_http_api.py
@@ -0,0 +1,80 @@
+import contextlib
+import json
+import os
+from typing import Any, AsyncIterator, Optional
+
+import httpx
+from httpx import Response
+
+import ragna
+from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement
+
+
+class HttpApiAssistant(Assistant):
+    _API_KEY_ENV_VAR: Optional[str]
+
+    @classmethod
+    def requirements(cls) -> list[Requirement]:
+        requirements: list[Requirement] = (
+            [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
+            if cls._API_KEY_ENV_VAR is not None
+            else []
+        )
+        requirements.extend(cls._extra_requirements())
+        return requirements
+
+    @classmethod
+    def _extra_requirements(cls) -> list[Requirement]:
+        return []
+
+    def __init__(self) -> None:
+        self._client = httpx.AsyncClient(
+            headers={"User-Agent": f"{ragna.__version__}/{self}"},
+            timeout=60,
+        )
+        self._api_key: Optional[str] = (
+            os.environ[self._API_KEY_ENV_VAR]
+            if self._API_KEY_ENV_VAR is not None
+            else None
+        )
+
+    async def _assert_api_call_is_success(self, response: Response) -> None:
+        if response.is_success:
+            return
+
+        content = await response.aread()
+        with contextlib.suppress(Exception):
+            content = json.loads(content)
+
+        raise RagnaException(
+            "API call failed",
+            request_method=response.request.method,
+            request_url=str(response.request.url),
+            response_status_code=response.status_code,
+            response_content=content,
+        )
+
+    async def _stream_sse(
+        self,
+        method: str,
+        url: str,
+        **kwargs: Any,
+    ) -> AsyncIterator[dict[str, Any]]:
+        import httpx_sse
+
+        async with httpx_sse.aconnect_sse(
+            self._client, method, url, **kwargs
+        ) as event_source:
+            await self._assert_api_call_is_success(event_source.response)
+
+            async for sse in event_source.aiter_sse():
+                yield json.loads(sse.data)
+
+    async def _stream_jsonl(
+        self, method: str, url: str, **kwargs: Any
+    ) -> AsyncIterator[dict[str, Any]]:
+        async with self._client.stream(method, url, **kwargs) as response:
+            await self._assert_api_call_is_success(response)
+
+            async for chunk in response.aiter_lines():
+                yield json.loads(chunk)
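Both `_stream_*` helpers yield parsed JSON dictionaries, so a concrete assistant only assembles the request and interprets events. A minimal sketch of a hypothetical subclass; the endpoint, payload shape, and `text` field are invented for illustration and are not part of this diff:

```python
from typing import AsyncIterator

from ragna.assistants._http_api import HttpApiAssistant
from ragna.core import Source


class MyJsonlAssistant(HttpApiAssistant):
    # Hypothetical local service: no API key required, streams JSON lines.
    _API_KEY_ENV_VAR = None

    async def answer(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
        async for event in self._stream_jsonl(
            "POST",
            "http://localhost:9000/generate",  # invented endpoint
            json={"prompt": prompt, "max_tokens": max_new_tokens},
        ):
            # The "text" field is an assumption about the hypothetical API.
            if "text" in event:
                yield event["text"]
```

Keeping the SSE and JSONL plumbing on the base class means each provider's `answer` is reduced to payload construction plus event interpretation.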
info "Required packages" + + - `httpx_sse` + """ + + _API_KEY_ENV_VAR = None + _STREAMING_METHOD = "sse" + _MODEL = None + + @property + def _url(self) -> str: + base_url = os.environ.get("RAGNA_LLAMAFILE_BASE_URL", "http://localhost:8080") + return f"{base_url}/v1/chat/completions" diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index a9dad3ee..37957be2 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,22 +1,26 @@ -import json -from typing import AsyncIterator, cast +import abc +from typing import Any, AsyncIterator, Literal, Optional, cast -from ragna.core import PackageRequirement, Requirement, Source +from ragna.core import PackageRequirement, RagnaException, Requirement, Source -from ._api import ApiAssistant +from ._http_api import HttpApiAssistant -class OpenaiApiAssistant(ApiAssistant): - _API_KEY_ENV_VAR = "OPENAI_API_KEY" - _MODEL: str +class OpenaiCompliantHttpApiAssistant(HttpApiAssistant): + _STREAMING_METHOD: Literal["sse", "jsonl"] + _MODEL: Optional[str] @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [PackageRequirement("httpx_sse")] + def requirements(cls) -> list[Requirement]: + requirements = super().requirements() + requirements.extend( + {"sse": [PackageRequirement("httpx_sse")]}.get(cls._STREAMING_METHOD, []) + ) + return requirements - @classmethod - def display_name(cls) -> str: - return f"OpenAI/{cls._MODEL}" + @property + @abc.abstractmethod + def _url(self) -> str: ... def _make_system_content(self, sources: list[Source]) -> str: # See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb @@ -27,50 +31,72 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - async def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + def _stream( + self, + method: str, + url: str, + **kwargs: Any, + ) -> AsyncIterator[dict[str, Any]]: + stream = { + "sse": self._stream_sse, + "jsonl": self._stream_jsonl, + }.get(self._STREAMING_METHOD) + if stream is None: + raise RagnaException + + return stream(method, url, **kwargs) + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - import httpx_sse - # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming - async with httpx_sse.aconnect_sse( - self._client, - "POST", - "https://api.openai.com/v1/chat/completions", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}", - }, - json={ - "messages": [ - { - "role": "system", - "content": self._make_system_content(sources), - }, - { - "role": "user", - "content": prompt, - }, - ], - "model": self._MODEL, - "temperature": 0.0, - "max_tokens": max_new_tokens, - "stream": True, - }, - ) as event_source: - await self._assert_api_call_is_success(event_source.response) - - async for sse in event_source.aiter_sse(): - data = json.loads(sse.data) - choice = data["choices"][0] - if choice["finish_reason"] is not None: - break - - yield cast(str, choice["delta"]["content"]) - - -class Gpt35Turbo16k(OpenaiApiAssistant): + headers = { + "Content-Type": "application/json", + } + if self._api_key is not None: + headers["Authorization"] = f"Bearer {self._api_key}" + + json_ = { + "messages": [ + { + "role": "system", + "content": self._make_system_content(sources), + }, + { + 
"role": "user", + "content": prompt, + }, + ], + "temperature": 0.0, + "max_tokens": max_new_tokens, + "stream": True, + } + if self._MODEL is not None: + json_["model"] = self._MODEL + + async for data in self._stream("POST", self._url, headers=headers, json=json_): + choice = data["choices"][0] + if choice["finish_reason"] is not None: + break + + yield cast(str, choice["delta"]["content"]) + + +class OpenaiAssistant(OpenaiCompliantHttpApiAssistant): + _API_KEY_ENV_VAR = "OPENAI_API_KEY" + _STREAMING_METHOD = "sse" + + @classmethod + def display_name(cls) -> str: + return f"OpenAI/{cls._MODEL}" + + @property + def _url(self) -> str: + return "https://api.openai.com/v1/chat/completions" + + +class Gpt35Turbo16k(OpenaiAssistant): """[OpenAI GPT-3.5](https://platform.openai.com/docs/models/gpt-3-5) !!! info "Required environment variables" @@ -85,7 +111,7 @@ class Gpt35Turbo16k(OpenaiApiAssistant): _MODEL = "gpt-3.5-turbo-16k" -class Gpt4(OpenaiApiAssistant): +class Gpt4(OpenaiAssistant): """[OpenAI GPT-4](https://platform.openai.com/docs/models/gpt-4) !!! info "Required environment variables" diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 97961456..02b964b5 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -4,21 +4,24 @@ from ragna import assistants from ragna._compat import anext -from ragna.assistants._api import ApiAssistant +from ragna.assistants._http_api import HttpApiAssistant from ragna.core import RagnaException from tests.utils import skip_on_windows -API_ASSISTANTS = [ +HTTP_API_ASSISTANTS = [ assistant for assistant in assistants.__dict__.values() if isinstance(assistant, type) - and issubclass(assistant, ApiAssistant) - and assistant is not ApiAssistant + and issubclass(assistant, HttpApiAssistant) + and assistant is not HttpApiAssistant ] @skip_on_windows -@pytest.mark.parametrize("assistant", API_ASSISTANTS) +@pytest.mark.parametrize( + "assistant", + [assistant for assistant in HTTP_API_ASSISTANTS if assistant._API_KEY_ENV_VAR], +) async def test_api_call_error_smoke(mocker, assistant): mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"})