Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to control version of astream events API #760

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def __init__(
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
) -> None:
"""Create an API handler for the given runnable.
Expand Down Expand Up @@ -595,6 +596,8 @@ def __init__(
If not provided, then all logs will be allowed to be streamed.
Use to also limit the events that can be streamed by the stream_events.
TODO: Introduce deprecation for this parameter to rename it
astream_events_version: version of the stream events endpoint to use.
By default "v2".
"""
if importlib.util.find_spec("sse_starlette") is None:
raise ImportError(
Expand Down Expand Up @@ -632,6 +635,7 @@ def __init__(
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
self._astream_events_version = astream_events_version

if token_feedback_config:
if len(token_feedback_config["key_configs"]) != 1:
Expand Down Expand Up @@ -1343,7 +1347,7 @@ async def _stream_events() -> AsyncIterator[dict]:
exclude_names=stream_events_request.exclude_names,
exclude_types=stream_events_request.exclude_types,
exclude_tags=stream_events_request.exclude_tags,
version="v1",
version=self._astream_events_version,
):
if (
self._names_in_stream_allow_list is None
Expand Down
22 changes: 17 additions & 5 deletions langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def _log_error_message_once(error_message: str) -> None:
logger.error(error_message)


# Deduplicate repeated log output: lru_cache memoizes on the message text, so
# each distinct message is logged at most once per process. Will accommodate
# up to 1_000 different messages before old entries are evicted.
@lru_cache(maxsize=1_000)
def _log_info_message_once(error_message: str) -> None:
    """Log an info message once.

    Args:
        error_message: the message to emit at INFO level. NOTE(review): the
            parameter name says "error" but this logs at info level (mirrors
            the sibling `_log_error_message_once`); name kept for call
            compatibility — consider renaming to `message` in a follow-up.
    """
    logger.info(error_message)


def _sanitize_request(request: httpx.Request) -> httpx.Request:
"""Remove sensitive headers from the request."""
accept_headers = {
Expand Down Expand Up @@ -752,7 +758,7 @@ async def astream_events(
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
version: Literal["v1", "v2", None] = None,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
Expand All @@ -775,21 +781,27 @@ async def astream_events(
input: The input to the runnable
config: The config to use for the runnable
version: The version of the astream_events to use.
Currently only "v1" is supported.
Currently, this input is IGNORED on the client.
The server will return whatever format it's configured with.
include_names: The names of the events to include
include_types: The types of the events to include
include_tags: The tags of the events to include
exclude_names: The names of the events to exclude
exclude_types: The types of the events to exclude
exclude_tags: The tags of the events to exclude
"""
if version != "v1":
raise ValueError(f"Unsupported version: {version}. Use 'v1'")

# Create a stream handler that will emit Log objects
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)

if version is not None:
_log_info_message_once(
"Versioning of the astream_events API is not supported on the client "
"side currently. The server will return events in whatever format "
"it was configured with in add_routes or APIHandler. "
"To stop seeing this message, remove the `version` argument."
)

events = []

run_manager = await callback_manager.on_chain_start(
Expand Down
11 changes: 7 additions & 4 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def add_routes(
enabled_endpoints: Optional[Sequence[EndpointName]] = None,
dependencies: Optional[Sequence[Depends]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
Expand Down Expand Up @@ -380,14 +381,16 @@ def add_routes(
- chat: UX is optimized for chat-like interactions. Please review
the README in langserve for more details about constraints (e.g.,
which message types are supported etc.)
astream_events_version: version of the stream events endpoint to use.
By default "v2".
""" # noqa: E501
if not isinstance(runnable, Runnable):
raise TypeError(
f"Expected a Runnable, got {type(runnable)}. "
f"The second argument to add_routes should be a Runnable instance."
f"add_route(app, runnable, ...) is the correct usage."
f"Please make sure that you are using a runnable which is an instance of "
f"langchain_core.runnables.Runnable."
"The second argument to add_routes should be a Runnable instance."
"add_route(app, runnable, ...) is the correct usage."
"Please make sure that you are using a runnable which is an instance of "
"langchain_core.runnables.Runnable."
)

endpoint_configuration = _EndpointConfiguration(
Expand Down
45 changes: 19 additions & 26 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,7 @@ def mul_two(y: int) -> int:
# test client side error
with pytest.raises(httpx.HTTPStatusError) as cb:
# Invalid input type (expected string but got int)
async for _ in runnable.astream_events("foo", version="v1"):
async for _ in runnable.astream_events("foo", version="v2"):
pass

# Verify that this is a 422 error
Expand All @@ -2315,7 +2315,7 @@ def mul_two(y: int) -> int:
with pytest.raises(httpx.HTTPStatusError) as cb:
# Invalid input type (expected string but got int)
# include names should not be a list of lists
async for _ in runnable.astream_events(1, include_names=[[]], version="v1"):
async for _ in runnable.astream_events(1, include_names=[[]], version="v2"):
pass

# Verify that this is a 422 error
Expand All @@ -2324,7 +2324,7 @@ def mul_two(y: int) -> int:
# Test good requests
events = []

async for event in runnable.astream_events(1, version="v1"):
async for event in runnable.astream_events(1, version="v2"):
events.append(event)

# validate events
Expand All @@ -2337,6 +2337,7 @@ def mul_two(y: int) -> int:
assert not k.startswith("__")
assert "metadata" in event
del event["metadata"]
event["parent_ids"] = []

assert events == [
{
Expand Down Expand Up @@ -2416,6 +2417,7 @@ def _clean_up_events(events: List[Dict[str, Any]]) -> None:
assert not k.startswith("__")
assert "metadata" in event
del event["metadata"]
event["parent_ids"] = []


async def test_astream_events_with_serialization(
Expand Down Expand Up @@ -2488,7 +2490,7 @@ def back_to_serializable(inputs) -> str:
app, raise_app_exceptions=False, path="/doc_types"
) as runnable:
# Test good requests
events = [event async for event in runnable.astream_events("foo", version="v1")]
events = [event async for event in runnable.astream_events("foo", version="v2")]
_clean_up_events(events)

assert events == [
Expand Down Expand Up @@ -2578,7 +2580,7 @@ def back_to_serializable(inputs) -> str:
app, raise_app_exceptions=False, path="/get_pets"
) as runnable:
# Test good requests
events = [event async for event in runnable.astream_events("foo", version="v1")]
events = [event async for event in runnable.astream_events("foo", version="v2")]
_clean_up_events(events)
assert events == [
{
Expand Down Expand Up @@ -2613,7 +2615,7 @@ def back_to_serializable(inputs) -> str:
) as runnable:
# Test good requests
with pytest.raises(httpx.HTTPStatusError) as cb:
async for event in runnable.astream_events("foo", version="v1"):
async for event in runnable.astream_events("foo", version="v2"):
pass
assert cb.value.response.status_code == 500

Expand Down Expand Up @@ -2641,7 +2643,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
events = [
event
async for event in runnable.astream_events(
{"question": "hello"}, version="v1"
{"question": "hello"}, version="v2"
)
]
_clean_up_events(events)
Expand Down Expand Up @@ -2850,25 +2852,16 @@ async def test_astream_events_with_prompt_model_parser_chain(
]
},
"output": {
"generations": [
[
{
"generation_info": None,
"message": {
"additional_kwargs": {},
"content": "Hello World!",
"name": None,
"response_metadata": {},
"type": "AIMessageChunk",
},
"text": "Hello World!",
"type": "ChatGenerationChunk",
}
]
],
"llm_output": None,
"run": None,
"type": "LLMResult",
"additional_kwargs": {},
"content": "Hello World!",
"example": False,
"invalid_tool_calls": [],
"name": None,
"response_metadata": {},
"tool_call_chunks": [],
"tool_calls": [],
"type": "AIMessageChunk",
"usage_metadata": None,
},
},
"event": "on_chat_model_end",
Expand Down
Loading