From b3e728d95d86db726cdaac3427979dd85a3f8410 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 20 Dec 2023 17:44:55 -0500 Subject: [PATCH] Only run per request modifier for invoke,batch,stream,astream_log (#341) This only runs the per request modifier for the endpoints where one would expect it to be run. If there is another use case, please file an issue explaining it and we'll accommodate it. https://github.com/langchain-ai/langserve/issues/325 --- langserve/api_handler.py | 28 +++++++++++-- langserve/server.py | 3 ++ tests/unit_tests/test_server_client.py | 55 ++++++++++++++++++++++++-- 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index af1844dc..8139a49d 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -1078,7 +1078,12 @@ async def input_schema( config_keys=self._config_keys, model=self._ConfigPayload, request=request, - per_req_config_modifier=self._per_req_config_modifier, + # Do not use per request config modifier for output schema + # since it's unclear why it would make sense to modify + # this using a per request config modifier. + # If this is needed, for some reason please file an issue explaining + # the user case. + per_req_config_modifier=None, server_config=server_config, ) config = _update_config_with_defaults( @@ -1101,7 +1106,12 @@ async def output_schema( config_keys=self._config_keys, model=self._ConfigPayload, request=request, - per_req_config_modifier=self._per_req_config_modifier, + # Do not use per request config modifier for output schema + # since it's unclear why it would make sense to modify + # this using a per request config modifier. + # If this is needed, for some reason please file an issue explaining + # the user case. + per_req_config_modifier=None, server_config=server_config, ) config = _update_config_with_defaults( @@ -1123,7 +1133,12 @@ async def config_schema( config_keys=self._config_keys, model=self._ConfigPayload, request=request, - per_req_config_modifier=self._per_req_config_modifier, + # Do not use per request config modifier for output schema + # since it's unclear why it would make sense to modify + # this using a per request config modifier. + # If this is needed, for some reason please file an issue explaining + # the user case. + per_req_config_modifier=None, server_config=server_config, ) config = _update_config_with_defaults( @@ -1150,7 +1165,12 @@ async def playground( config_keys=self._config_keys, model=self._ConfigPayload, request=request, - per_req_config_modifier=self._per_req_config_modifier, + # Do not use per request config modifier for output schema + # since it's unclear why it would make sense to modify + # this using a per request config modifier. + # If this is needed, for some reason please file an issue explaining + # the user case. + per_req_config_modifier=None, server_config=server_config, ) diff --git a/langserve/server.py b/langserve/server.py index 7f27df00..2bb96de3 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -274,6 +274,9 @@ def add_routes( for example, if the user wants to pass in a header containing credentials to a runnable. The RunnableConfig is presented in its dictionary form. Note that only keys in `config_keys` will be modifiable by this function. + As of 0.0.37, this function is only called for the invoke, batch, stream, + and stream_log endpoints. This function is not called for the playground, + input_schema, output_schema, and config_schema endpoints etc. enable_feedback_endpoint: Whether to enable an endpoint for logging feedback to LangSmith. Enabled by default. If this flag is disabled or LangSmith tracing is not enabled for the runnable, then 400 errors will be thrown diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 4c22ac4d..2dd38c64 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -1778,9 +1778,7 @@ async def test_enforce_trailing_slash_in_client() -> None: assert r.url == "nosuchurl/" -async def test_per_request_config_modifier( - event_loop: AbstractEventLoop, mocker: MockerFixture -) -> None: +async def test_per_request_config_modifier(event_loop: AbstractEventLoop) -> None: """Test updating the config based on the raw request object.""" async def add_one(x: int) -> int: @@ -1810,6 +1808,57 @@ def header_passthru_modifier( per_req_config_modifier=header_passthru_modifier, ) + # this test verifies that per request modifier is only + # applied for the expected endpoints + def buggy_modifier(config: Dict[str, Any], request: Request) -> Dict[str, Any]: + """Update the config""" + raise ValueError("oops I did it again") + + add_routes( + app, + server_runnable, + path="/with_buggy_modifier", + per_req_config_modifier=buggy_modifier, + ) + + async with get_async_test_client( + app, + raise_app_exceptions=False, + ) as async_client: + endpoints_to_test = ( + "invoke", + "batch", + "stream", + "stream_log", + "input_schema", + "output_schema", + "config_schema", + "playground/index.html", + ) + + for endpoint in endpoints_to_test: + url = "/with_buggy_modifier/" + endpoint + + if endpoint == "batch": + payload = {"inputs": [1, 2]} + response = await async_client.post(url, json=payload) + elif endpoint in {"invoke", "stream", "stream_log"}: + payload = {"input": 1} + response = await async_client.post(url, json=payload) + elif endpoint in {"input_schema", "output_schema", "config_schema"}: + response = await async_client.get(url) + elif endpoint == "playground/index.html": + response = await async_client.get(url) + else: + raise ValueError(f"Unknown endpoint {endpoint}") + + if endpoint in {"invoke", "batch"}: + assert response.status_code == 500 + elif endpoint in {"stream", "stream_log"}: + assert '"status_code": 500' in response.text + else: + assert response.status_code != 500 + async def test_uuid_serialization(event_loop: AbstractEventLoop) -> None: """Test updating the config based on the raw request object."""