Skip to content

Commit

Permalink
[ENH] Add support for Ollama assistants (#376)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <[email protected]>
  • Loading branch information
2 people authored and blakerosenthal committed Jul 17, 2024
1 parent 34c26cd commit eb486f0
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 126 deletions.
8 changes: 8 additions & 0 deletions docs/examples/gallery_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
# - [ragna.assistants.Gpt4][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]
# - [Ollama](https://ollama.com/)
# - [ragna.assistants.OllamaGemma2B][]
# - [ragna.assistants.OllamaLlama2][]
# - [ragna.assistants.OllamaLlava][]
# - [ragna.assistants.OllamaMistral][]
# - [ragna.assistants.OllamaMixtral][]
# - [ragna.assistants.OllamaOrcaMini][]
# - [ragna.assistants.OllamaPhi2][]

from ragna import assistants

Expand Down
8 changes: 8 additions & 0 deletions docs/tutorials/gallery_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@
# - [ragna.assistants.Jurassic2Ultra][]
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
# - [ragna.assistants.LlamafileAssistant][]
# - [Ollama](https://ollama.com/)
# - [ragna.assistants.OllamaGemma2B][]
# - [ragna.assistants.OllamaLlama2][]
# - [ragna.assistants.OllamaLlava][]
# - [ragna.assistants.OllamaMistral][]
# - [ragna.assistants.OllamaMixtral][]
# - [ragna.assistants.OllamaOrcaMini][]
# - [ragna.assistants.OllamaPhi2][]
#
# !!! note
#
Expand Down
16 changes: 16 additions & 0 deletions ragna/assistants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
"CommandLight",
"GeminiPro",
"GeminiUltra",
"OllamaGemma2B",
"OllamaPhi2",
"OllamaLlama2",
"OllamaLlava",
"OllamaMistral",
"OllamaMixtral",
"OllamaOrcaMini",
"Gpt35Turbo16k",
"Gpt4",
"Jurassic2Ultra",
Expand All @@ -19,6 +26,15 @@
from ._demo import RagnaDemoAssistant
from ._google import GeminiPro, GeminiUltra
from ._llamafile import LlamafileAssistant
from ._ollama import (
OllamaGemma2B,
OllamaLlama2,
OllamaLlava,
OllamaMistral,
OllamaMixtral,
OllamaOrcaMini,
OllamaPhi2,
)
from ._openai import Gpt4, Gpt35Turbo16k

# isort: split
Expand Down
10 changes: 5 additions & 5 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class Ai21LabsAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "AI21_API_KEY"
_STREAMING_PROTOCOL = None
_MODEL_TYPE: str

@classmethod
Expand All @@ -27,7 +28,8 @@ async def answer(
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
response = await self._client.post(
async for data in self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
headers={
"accept": "application/json",
Expand All @@ -46,10 +48,8 @@ async def answer(
],
"system": self._make_system_content(sources),
},
)
await self._assert_api_call_is_success(response)

yield cast(str, response.json()["outputs"][0]["text"])
):
yield cast(str, data["outputs"][0]["text"])


# The Jurassic2Mid assistant receives a 500 internal service error from the remote
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from ragna.core import PackageRequirement, RagnaException, Requirement, Source

from ._http_api import HttpApiAssistant
from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class AnthropicAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
_STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
_MODEL: str

@classmethod
Expand Down Expand Up @@ -40,7 +41,7 @@ async def answer(
) -> AsyncIterator[str]:
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
async for data in self._stream_sse(
async for data in self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from ragna.core import RagnaException, Source

from ._http_api import HttpApiAssistant
from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class CohereAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "COHERE_API_KEY"
_STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
_MODEL: str

@classmethod
Expand All @@ -29,7 +30,7 @@ async def answer(
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
async for event in self._stream_jsonl(
async for event in self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
Expand Down
49 changes: 11 additions & 38 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,15 @@
from typing import AsyncIterator

from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source
from ragna.core import Source

from ._http_api import HttpApiAssistant


# ijson does not support reading from an (async) iterator, but only from file-like
# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
# See https://github.com/ICRAR/ijson/issues/44 for details.
# ijson actually doesn't care about most of the file interface and only requires the
# read() method to be present.
class AsyncIteratorReader:
    """Minimal file-like adapter exposing ``read()`` over an async byte iterator."""

    def __init__(self, ait: AsyncIterator[bytes]) -> None:
        # Async iterator yielding raw byte chunks; consumed one chunk per read().
        self._ait = ait

    async def read(self, n: int) -> bytes:
        """Return the next available chunk, deliberately ignoring the size hint.

        ijson calls ``read(0)`` once to probe the return type and set up
        decoding, so that case must yield an empty bytes object. For any other
        ``n`` we hand back whatever chunk arrives next (or ``b""`` once the
        iterator is exhausted) instead of buffering to exactly ``n`` bytes, so
        data is forwarded as soon as it is available.
        """
        if n == 0:
            return b""
        try:
            return await self._ait.__anext__()
        except StopAsyncIteration:
            return b""
from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class GoogleAssistant(HttpApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_STREAMING_PROTOCOL = HttpStreamingProtocol.JSON
_MODEL: str

@classmethod
def _extra_requirements(cls) -> list[Requirement]:
return [PackageRequirement("ijson")]

@classmethod
def display_name(cls) -> str:
return f"Google/{cls._MODEL}"
Expand All @@ -51,9 +28,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
import ijson

async with self._client.stream(
async for chunk in self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
Expand All @@ -64,7 +39,10 @@ async def answer(
],
# https://ai.google.dev/docs/safety_setting_gemini
"safetySettings": [
{"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
{
"category": f"HARM_CATEGORY_{category}",
"threshold": "BLOCK_NONE",
}
for category in [
"HARASSMENT",
"HATE_SPEECH",
Expand All @@ -78,14 +56,9 @@ async def answer(
"maxOutputTokens": max_new_tokens,
},
},
) as response:
await self._assert_api_call_is_success(response)

async for chunk in ijson.items(
AsyncIteratorReader(response.aiter_bytes(1024)),
"item.candidates.item.content.parts.item.text",
):
yield chunk
parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
):
yield chunk


class GeminiPro(GoogleAssistant):
Expand Down
Loading

0 comments on commit eb486f0

Please sign in to comment.