fix streaming handling for builtin assistants #462

Merged: 8 commits, Aug 1, 2024
3 changes: 3 additions & 0 deletions environment-dev.yml
@@ -6,6 +6,9 @@ dependencies:
- pip
- git-lfs
- pip:
- httpx_sse
- ijson
- sse_starlette
- python-dotenv
- pytest >=6
- pytest-mock
2 changes: 1 addition & 1 deletion ragna/_utils.py
@@ -94,7 +94,7 @@ def timeout_after(
seconds: float = 30, *, message: str = ""
) -> Callable[[Callable], Callable]:
timeout = f"Timeout after {seconds:.1f} seconds"
message = timeout if message else f"{timeout}: {message}"
message = f"{timeout}: {message}" if message else timeout

def decorator(fn: Callable) -> Callable:
if is_debugging():
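The change above fixes an inverted conditional: previously the custom message was dropped whenever one was supplied, and an empty message produced a dangling colon. A minimal sketch of the corrected formatting, using only the names visible in the diff (the message value is illustrative):

seconds, message = 30.0, "no response from assistant"
timeout = f"Timeout after {seconds:.1f} seconds"
# Fixed logic: append the caller's message when given, otherwise use the bare timeout text.
message = f"{timeout}: {message}" if message else timeout
assert message == "Timeout after 30.0 seconds: no response from assistant"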
7 changes: 4 additions & 3 deletions ragna/assistants/_ai21labs.py
@@ -29,7 +29,7 @@ async def answer(
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._call_api(
async with self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
headers={
@@ -49,8 +49,9 @@
],
"system": self._make_system_content(sources),
},
):
yield cast(str, data["outputs"][0]["text"])
) as stream:
async for data in stream:
yield cast(str, data["outputs"][0]["text"])


# The Jurassic2Mid assistant receives a 500 internal service error from the remote
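All of the assistant diffs in this PR follow the same shape: `_call_api` no longer returns a bare async iterator but an async context manager that yields one, so the underlying HTTP response stays open while the caller iterates and is closed when the `async with` block exits. A schematic of the consuming side (not Ragna's actual classes; the URL and payload field are made up):

from typing import Any, AsyncIterator


async def answer_sketch(
    call_api: Any,  # stands in for self._call_api
) -> AsyncIterator[str]:
    # call_api(...) returns an AsyncContextManager[AsyncIterator[Any]]
    async with call_api(
        "POST", "https://example.invalid/v1/chat", json={"stream": True}
    ) as stream:
        async for data in stream:
            yield data["text"]  # field name is illustrative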
23 changes: 12 additions & 11 deletions ragna/assistants/_anthropic.py
@@ -42,7 +42,7 @@ async def answer(
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._call_api(
async with self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
@@ -59,16 +59,17 @@
"temperature": 0.0,
"stream": True,
},
):
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))
) as stream:
async for data in stream:
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))


class ClaudeOpus(AnthropicAssistant):
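The branches above distinguish four kinds of streamed payloads. Schematically, with event contents illustrative and only the field names taken from the code:

events = [
    {"type": "content_block_delta", "delta": {"text": "Hel"}},          # yielded
    {"type": "content_block_start", "index": 0},                        # skipped by the `continue`
    {"error": {"type": "overloaded_error", "message": "Overloaded"}},   # raises RagnaException
    {"type": "message_stop"},                                           # terminates the loop
]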
19 changes: 10 additions & 9 deletions ragna/assistants/_cohere.py
@@ -31,7 +31,7 @@ async def answer(
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
prompt, sources = (message := messages[-1]).content, message.sources
async for event in self._call_api(
async with self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
@@ -48,14 +48,15 @@
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
},
):
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])
) as stream:
async for event in stream:
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])


class Command(CohereAssistant):
7 changes: 4 additions & 3 deletions ragna/assistants/_google.py
@@ -29,7 +29,7 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for chunk in self._call_api(
async with self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
@@ -58,8 +58,9 @@
},
},
parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
):
yield chunk
) as stream:
async for chunk in stream:
yield chunk


class GeminiPro(GoogleAssistant):
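The `parse_kwargs=dict(item=...)` argument is an ijson prefix: Gemini's streamGenerateContent endpoint responds with a single JSON array that arrives incrementally, and ijson walks it as it streams, yielding every value found at that path. A rough, self-contained illustration of what the prefix selects (the payload shape is inferred from the prefix itself, not copied from the API):

import io

import ijson

# Two array elements shaped like the streamed response the prefix is written for.
payload = io.BytesIO(
    b'[{"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]},'
    b' {"candidates": [{"content": {"parts": [{"text": " world"}]}}]}]'
)

# "item" steps into each element of an array, so the prefix reads:
# top-level array -> candidates[] -> content -> parts[] -> text
texts = list(ijson.items(payload, "item.candidates.item.content.parts.item.text"))
assert texts == ["Hello", " world"]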
40 changes: 29 additions & 11 deletions ragna/assistants/_http_api.py
@@ -2,7 +2,7 @@
import enum
import json
import os
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncContextManager, AsyncIterator, Optional

import httpx

@@ -47,7 +47,7 @@ def __call__(
*,
parse_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
) -> AsyncContextManager[AsyncIterator[Any]]:
if self._protocol is None:
call_method = self._no_stream
else:
@@ -56,8 +56,10 @@
HttpStreamingProtocol.JSONL: self._stream_jsonl,
HttpStreamingProtocol.JSON: self._stream_json,
}[self._protocol]

return call_method(method, url, parse_kwargs=parse_kwargs or {}, **kwargs)

@contextlib.asynccontextmanager
async def _no_stream(
self,
method: str,
@@ -68,8 +70,13 @@ async def _no_stream(
) -> AsyncIterator[Any]:
response = await self._client.request(method, url, **kwargs)
await self._assert_api_call_is_success(response)
yield response.json()

async def stream() -> AsyncIterator[Any]:
yield response.json()

yield stream()

@contextlib.asynccontextmanager
async def _stream_sse(
self,
method: str,
@@ -85,9 +92,13 @@ async def _stream_sse(
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
yield json.loads(sse.data)
async def stream() -> AsyncIterator[Any]:
async for sse in event_source.aiter_sse():
yield json.loads(sse.data)

yield stream()

@contextlib.asynccontextmanager
async def _stream_jsonl(
self,
method: str,
@@ -99,8 +110,11 @@ async def _stream_jsonl(
async with self._client.stream(method, url, **kwargs) as response:
await self._assert_api_call_is_success(response)

async for chunk in response.aiter_lines():
yield json.loads(chunk)
async def stream() -> AsyncIterator[Any]:
async for chunk in response.aiter_lines():
yield json.loads(chunk)

yield stream()

# ijson does not support reading from an (async) iterator, but only from file-like
# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
@@ -120,6 +134,7 @@ async def read(self, n: int) -> bytes:
return b""
return await anext(self._ait, b"") # type: ignore[call-arg]

@contextlib.asynccontextmanager
async def _stream_json(
self,
method: str,
@@ -136,10 +151,13 @@ async def _stream_json(
async with self._client.stream(method, url, **kwargs) as response:
await self._assert_api_call_is_success(response)

async for chunk in ijson.items(
self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
):
yield chunk
async def stream() -> AsyncIterator[Any]:
async for chunk in ijson.items(
self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
):
yield chunk

yield stream()

async def _assert_api_call_is_success(self, response: httpx.Response) -> None:
if response.is_success:
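The recurring pattern above, an `@contextlib.asynccontextmanager` function that opens the HTTP response, defines a local `stream()` async generator, and then yields `stream()`, is what lets callers write `async with ... as stream: async for ...` while the response is only closed after iteration completes. A stripped-down, standalone sketch of the same shape, with prints standing in for the real resource handling and no httpx involved:

import asyncio
import contextlib
from typing import AsyncIterator


@contextlib.asynccontextmanager
async def fake_stream() -> AsyncIterator[AsyncIterator[int]]:
    # Stands in for `async with self._client.stream(...) as response:`
    print("response opened")

    async def stream() -> AsyncIterator[int]:
        for chunk in (1, 2, 3):
            yield chunk

    try:
        yield stream()
    finally:
        # Runs only after the caller's `async with` block exits, i.e. after iteration.
        print("response closed")


async def main() -> None:
    async with fake_stream() as stream:
        async for chunk in stream:
            print(chunk)


asyncio.run(main())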
17 changes: 10 additions & 7 deletions ragna/assistants/_ollama.py
@@ -33,13 +33,16 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
if "error" in data:
raise RagnaException(data["error"])
if not data["done"]:
yield cast(str, data["message"]["content"])
async with self._call_openai_api(
prompt, sources, max_new_tokens=max_new_tokens
) as stream:
async for data in stream:
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
if "error" in data:
raise RagnaException(data["error"])
if not data["done"]:
yield cast(str, data["message"]["content"])


class OllamaGemma2B(OllamaAssistant):
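For orientation, the loop above reads only `error`, `done`, and `message.content` from each JSONL chunk. Example chunks in that trimmed-down shape (real Ollama chat responses carry additional keys):

chunks = [
    {"message": {"content": "Hel"}, "done": False},  # yielded
    {"message": {"content": "lo"}, "done": False},   # yielded
    {"done": True},                                   # final chunk, nothing yielded
    # {"error": "model not found"}                    # would raise RagnaException
]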
21 changes: 12 additions & 9 deletions ragna/assistants/_openai.py
@@ -1,6 +1,6 @@
import abc
from functools import cached_property
from typing import Any, AsyncIterator, Optional, cast
from typing import Any, AsyncContextManager, AsyncIterator, Optional, cast

from ragna.core import Message, Source

@@ -23,9 +23,9 @@ def _make_system_content(self, sources: list[Source]) -> str:
)
return instruction + "\n\n".join(source.content for source in sources)

def _stream(
def _call_openai_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[dict[str, Any]]:
) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]:
# See https://platform.openai.com/docs/api-reference/chat/create
# and https://platform.openai.com/docs/api-reference/chat/streaming
headers = {
@@ -58,12 +58,15 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break

yield cast(str, choice["delta"]["content"])
async with self._call_openai_api(
prompt, sources, max_new_tokens=max_new_tokens
) as stream:
async for data in stream:
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break

yield cast(str, choice["delta"]["content"])


class OpenaiAssistant(OpenaiLikeHttpApiAssistant):
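The loop above reads exactly two fields from each streamed chunk. In the chat-completions streaming format those chunks look roughly like this (values illustrative, trimmed to the fields that are read):

chunks = [
    {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]},
    {"choices": [{"delta": {}, "finish_reason": "stop"}]},  # loop breaks before reading "content"
]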
49 changes: 49 additions & 0 deletions tests/assistants/streaming_server.py
@@ -0,0 +1,49 @@
import json
import random

import sse_starlette
from fastapi import FastAPI, Request, Response, status
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/health")
async def health():
return Response(b"", status_code=status.HTTP_200_OK)


@app.post("/sse")
async def sse(request: Request):
data = await request.json()

async def stream():
for obj in data:
yield sse_starlette.ServerSentEvent(json.dumps(obj))

return sse_starlette.EventSourceResponse(stream())


@app.post("/jsonl")
async def jsonl(request: Request):
data = await request.json()

async def stream():
for obj in data:
yield f"{json.dumps(obj)}\n"

return StreamingResponse(stream())


@app.post("/json")
async def json_(request: Request):
data = await request.body()

async def stream():
start = 0
while start < len(data):
end = start + random.randint(1, 10)
yield data[start:end]
start = end

return StreamingResponse(stream())
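One way the /jsonl route could be exercised once the app is served, with the host and port assumed; this is not the PR's actual test code:

import json

import httpx


async def roundtrip_jsonl() -> list[dict]:
    objs = [{"text": "Hello"}, {"text": " world"}]
    received = []
    async with httpx.AsyncClient(base_url="http://127.0.0.1:8000") as client:
        # The server echoes the posted list back as one JSON object per line.
        async with client.stream("POST", "/jsonl", json=objs) as response:
            async for line in response.aiter_lines():
                received.append(json.loads(line))
    return received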