diff --git a/examples/custom_events/server.py b/examples/custom_events/server.py new file mode 100755 index 00000000..32e1e092 --- /dev/null +++ b/examples/custom_events/server.py @@ -0,0 +1,67 @@ +""" +Allows the `/stream` endpoint to return `sse_starlette.ServerSentEvent` from runnable, +allowing you to return custom events such as `event: error`. +""" + +from typing import Any, AsyncIterator, Dict + +from fastapi import FastAPI +from langchain_core.runnables import RunnableConfig, RunnableLambda +from sse_starlette import ServerSentEvent + +from langserve import add_routes +from langserve.pydantic_v1 import BaseModel + +app = FastAPI( + title="LangChain Server", + version="1.0", + description="Spin up a simple api server using Langchain's Runnable interfaces", +) + + +class InputType(BaseModel): ... + + +class OutputType(BaseModel): + message: str + + +async def error_event( + _: InputType, + config: RunnableConfig, +) -> AsyncIterator[Dict[str, Any] | ServerSentEvent]: + for i in range(4): + yield { + "message": f"Message {i}", + } + + is_streaming = False + if "metadata" in config: + metadata = config["metadata"] + if "__langserve_endpoint" in metadata: + is_streaming = metadata["__langserve_endpoint"] == "stream" + + if is_streaming: + yield ServerSentEvent( + data={ + "message": "An error occurred", + }, + event="error", + ) + else: + yield { + "message": "An error occurred", + } + + +add_routes( + app, + RunnableLambda(error_event), + input_type=InputType, + output_type=OutputType, +) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="localhost", port=8000) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index d6b24b58..5a879f16 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -78,9 +78,10 @@ from langserve.version import __version__ try: - from sse_starlette import EventSourceResponse + from sse_starlette import EventSourceResponse, ServerSentEvent except ImportError: EventSourceResponse = Any + ServerSentEvent = Any def _is_hosted() -> bool: @@ -833,6 +834,9 @@ async def invoke( output = await invoke_coro feedback_token = None + if ServerSentEvent is not Any and isinstance(output, ServerSentEvent): + output = output.data + if self._include_callback_events: callback_events = [ _scrub_exceptions_in_event(event) @@ -1111,7 +1115,7 @@ async def stream( feedback_key = None task = None - async def _stream() -> AsyncIterator[dict]: + async def _stream() -> AsyncIterator[dict | ServerSentEvent]: """Stream the output of the runnable.""" try: config_w_callbacks = config.copy() @@ -1136,6 +1140,15 @@ async def _stream() -> AsyncIterator[dict]: run_id, feedback_key, feedback_token ) + if ServerSentEvent is not Any and isinstance( + chunk, ServerSentEvent + ): + yield { + "event": chunk.event, + "data": self._serializer.dumps(chunk.data).decode("utf-8"), + } + continue + yield { # EventSourceResponse expects a string for data # so after serializing into bytes, we decode into utf-8