From 9ca408610a7293541b5063136ec308801d3e6752 Mon Sep 17 00:00:00 2001 From: Takuya Igei <28337009+itok01@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:11:26 +0900 Subject: [PATCH 1/4] add: Support for Server-Sent Events in APIHandler This commit adds support for Server-Sent Events (SSE) in the APIHandler class. It imports the necessary modules from `sse_starlette` and includes a check for the `ServerSentEvent` class. If SSE is supported and the chunk is an instance of `ServerSentEvent`, it yields the chunk. This change improves the functionality of the APIHandler by enabling SSE support. --- langserve/api_handler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index d6b24b58..dd1cf04c 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -78,9 +78,10 @@ from langserve.version import __version__ try: - from sse_starlette import EventSourceResponse + from sse_starlette import EventSourceResponse, ServerSentEvent except ImportError: EventSourceResponse = Any + ServerSentEvent = Any def _is_hosted() -> bool: @@ -1111,7 +1112,7 @@ async def stream( feedback_key = None task = None - async def _stream() -> AsyncIterator[dict]: + async def _stream() -> AsyncIterator[dict | ServerSentEvent]: """Stream the output of the runnable.""" try: config_w_callbacks = config.copy() @@ -1136,6 +1137,12 @@ async def _stream() -> AsyncIterator[dict]: run_id, feedback_key, feedback_token ) + if ServerSentEvent is not Any and isinstance( + chunk, ServerSentEvent + ): + yield chunk + continue + yield { # EventSourceResponse expects a string for data # so after serializing into bytes, we decode into utf-8 From 8e153641b3ef798919b63215cdf43dfda77bb898 Mon Sep 17 00:00:00 2001 From: Takuya Igei <28337009+itok01@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:39:47 +0900 Subject: [PATCH 2/4] feat: yield serialized Server-Sent Events in APIHandler --- langserve/api_handler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index dd1cf04c..8e684766 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -1140,7 +1140,10 @@ async def _stream() -> AsyncIterator[dict | ServerSentEvent]: if ServerSentEvent is not Any and isinstance( chunk, ServerSentEvent ): - yield chunk + yield { + "event": chunk.event, + "data": self._serializer.dumps(chunk.data).decode("utf-8"), + } continue yield { From e355873cbfc8a102dc69fde7c9feafb912057c2b Mon Sep 17 00:00:00 2001 From: Takuya Igei <28337009+itok01@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:40:16 +0900 Subject: [PATCH 3/4] add: example for custom events in LangChain Server --- examples/custom_events/server.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100755 examples/custom_events/server.py diff --git a/examples/custom_events/server.py b/examples/custom_events/server.py new file mode 100755 index 00000000..32e1e092 --- /dev/null +++ b/examples/custom_events/server.py @@ -0,0 +1,67 @@ +""" +Allows the `/stream` endpoint to return `sse_starlette.ServerSentEvent` from runnable, +allowing you to return custom events such as `event: error`. +""" + +from typing import Any, AsyncIterator, Dict + +from fastapi import FastAPI +from langchain_core.runnables import RunnableConfig, RunnableLambda +from sse_starlette import ServerSentEvent + +from langserve import add_routes +from langserve.pydantic_v1 import BaseModel + +app = FastAPI( + title="LangChain Server", + version="1.0", + description="Spin up a simple api server using Langchain's Runnable interfaces", +) + + +class InputType(BaseModel): ... + + +class OutputType(BaseModel): + message: str + + +async def error_event( + _: InputType, + config: RunnableConfig, +) -> AsyncIterator[Dict[str, Any] | ServerSentEvent]: + for i in range(4): + yield { + "message": f"Message {i}", + } + + is_streaming = False + if "metadata" in config: + metadata = config["metadata"] + if "__langserve_endpoint" in metadata: + is_streaming = metadata["__langserve_endpoint"] == "stream" + + if is_streaming: + yield ServerSentEvent( + data={ + "message": "An error occurred", + }, + event="error", + ) + else: + yield { + "message": "An error occurred", + } + + +add_routes( + app, + RunnableLambda(error_event), + input_type=InputType, + output_type=OutputType, +) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="localhost", port=8000) From f8fc2455b070970b00288ab29e55259b38fad767 Mon Sep 17 00:00:00 2001 From: Takuya Igei <28337009+itok01@users.noreply.github.com> Date: Thu, 4 Jul 2024 20:16:08 +0900 Subject: [PATCH 4/4] update: support ServerSentEvent in `/invoke` --- langserve/api_handler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 8e684766..5a879f16 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -834,6 +834,9 @@ async def invoke( output = await invoke_coro feedback_token = None + if ServerSentEvent is not Any and isinstance(output, ServerSentEvent): + output = output.data + if self._include_callback_events: callback_events = [ _scrub_exceptions_in_event(event)