Skip to content

Commit

Permalink
Add StreamEventRequest type, update OpenAPIDocs (#422)
Browse files Browse the repository at this point in the history
This PR adds a type for StreamEventsRequest and updates OpenAPI docs.
  • Loading branch information
eyurtsev authored Jan 26, 2024
1 parent f41fb26 commit 02319e5
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 1 deletion.
10 changes: 10 additions & 0 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
create_batch_response_model,
create_invoke_request_model,
create_invoke_response_model,
create_stream_events_request_model,
create_stream_log_request_model,
create_stream_request_model,
)
Expand Down Expand Up @@ -590,6 +591,10 @@ def __init__(
self._StreamLogRequest = create_stream_log_request_model(
model_namespace, input_type_, self._ConfigPayload
)

self._StreamEventsRequest = create_stream_events_request_model(
model_namespace, input_type_, self._ConfigPayload
)
# Generate the response models
self._InvokeResponse = create_invoke_response_model(
model_namespace, output_type_
Expand All @@ -616,6 +621,11 @@ def StreamLogRequest(self) -> Type[BaseModel]:
"""Return the stream log request model."""
return self._StreamLogRequest

@property
def StreamEventsRequest(self) -> Type[BaseModel]:
"""Return the stream events request model."""
return self._StreamEventsRequest

@property
def InvokeResponse(self) -> Type[BaseModel]:
"""Return the invoke response model."""
Expand Down
106 changes: 105 additions & 1 deletion langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,15 @@ def add_routes(
# Determine the base URL for the playground endpoint
prefix = app.prefix if isinstance(app, APIRouter) else "" # type: ignore

# Let's check if the runnable has a astream events property
# It's a new method on runnables that allows them to stream events.
# We'll only add this if folks are on recent versions of langchain-core.
# This is done so that folks can upgrade langserve without having to
# upgrade langchain-core if they need other fixes.
# We can likely remove in a few months and bump minimal version of langchain
# required by langserve.
has_astream_events = hasattr(runnable, "astream_events")

api_handler = APIHandler(
runnable,
path=path,
Expand Down Expand Up @@ -549,7 +558,7 @@ async def stream_log_with_config(
# that are used by the runnnable (e.g., input, config fields)
return await api_handler.stream_log(request, config_hash=config_hash)

if endpoint_configuration.is_stream_events_enabled:
if has_astream_events and endpoint_configuration.is_stream_events_enabled:

@app.post(
f"{namespace}/stream_events",
Expand Down Expand Up @@ -690,6 +699,7 @@ async def config_schema_with_config(
BatchResponse = api_handler.BatchResponse
StreamRequest = api_handler.StreamRequest
StreamLogRequest = api_handler.StreamLogRequest
StreamEventsRequest = api_handler.StreamEventsRequest

if endpoint_configuration.is_invoke_enabled:

Expand Down Expand Up @@ -913,3 +923,97 @@ async def _stream_log_docs(
),
dependencies=dependencies,
)(_stream_log_docs)

if has_astream_events and endpoint_configuration.is_stream_events_enabled:

async def _stream_events_docs(
stream_events_request: Annotated[
StreamEventsRequest, StreamEventsRequest
],
config_hash: str = "",
) -> EventSourceResponse:
"""Stream events from the given runnable.
This endpoint allows to stream events from the runnable, including
events from all intermediate steps.
**Attention**
This is a new endpoint that only works for langchain-core >= 0.1.14.
It belongs to a Beta API that may change in the future.
**Important**
Specify filters to the events you want to receive by setting
the appropriate filters in the request body.
This will help avoid sending too much data over the network.
It will also prevent serialization issues with
any unsupported types since it won't need to serialize events
that aren't transmitted.
The endpoint uses a server sent event stream to stream the output.
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
The encoding of events follows the following format:
* data - for streaming the output of the runnable
{
"event": "data",
"data": {
...
}
}
* error - for signaling an error in the stream, also ends the stream.
{
"event": "error",
"data": {
"status_code": 500,
"message": "Internal Server Error"
}
}
* end - for signaling the end of the stream.
This helps the client to know when to stop listening for events and
know that the streaming has ended successfully.
{
"event": "end",
}
`data` for the `data` event is a JSON object that corresponds
to a serialized representation of a StreamEvent.
See LangChain documentation for more information about astream_events.
"""
raise AssertionError("This endpoint should not be reachable.")

app.post(
f"{namespace}/stream_events",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream_events"),
dependencies=dependencies,
)(_stream_events_docs)

if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/stream_events",
include_in_schema=True,
tags=route_tags_with_config,
name=_route_name_with_config("stream_events"),
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream_events endpoint "
"without the `c/{config_hash}` path parameter."
),
dependencies=dependencies,
)(_stream_events_docs)
58 changes: 58 additions & 0 deletions langserve/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,64 @@ def create_stream_log_request_model(
return stream_log_request


def create_stream_events_request_model(
namespace: str,
input_type: Validator,
config: Type[BaseModel],
) -> Type[BaseModel]:
"""Create a pydantic model for the stream events request."""
stream_events_request = create_model(
f"{namespace}StreamEventsRequest",
input=(input_type, ...),
config=(config, Field(default_factory=dict)),
include_names=(
Optional[Sequence[str]],
Field(
None,
description="If specified, filter to runnables with matching names",
),
),
include_types=(
Optional[Sequence[str]],
Field(
None,
description="If specified, filter to runnables with matching types",
),
),
include_tags=(
Optional[Sequence[str]],
Field(
None,
description="If specified, filter to runnables with matching tags",
),
),
exclude_names=(
Optional[Sequence[str]],
Field(
None,
description="If specified, exclude runnables with matching names",
),
),
exclude_types=(
Optional[Sequence[str]],
Field(
None,
description="If specified, exclude runnables with matching types",
),
),
exclude_tags=(
Optional[Sequence[str]],
Field(
None,
description="If specified, exclude runnables with matching tags",
),
),
kwargs=(dict, Field(default_factory=dict)),
)
stream_events_request.update_forward_refs()
return stream_events_request


class InvokeBaseResponse(BaseModel):
"""Base class for invoke request."""

Expand Down

0 comments on commit 02319e5

Please sign in to comment.