From f691053c32c0148c6b321a57a72d493948b1e23d Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 13 Nov 2023 12:10:34 -0500 Subject: [PATCH] x --- langserve/server.py | 12 ++++++++++-- tests/unit_tests/test_server_client.py | 17 ++++++++++++++--- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/langserve/server.py b/langserve/server.py index aec218fb..6c88b5bf 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -114,6 +114,13 @@ def _unpack_request_config( else: raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}") config = merge_configs(*config_dicts) + if "configurable" in config and config["configurable"]: + if "configurable" not in keys: + raise HTTPException( + 422, + "Server code has modified the default accepted config keys to " + "not accept `configurable`. ", + ) projected_config = {k: config[k] for k in keys if k in config} return ( per_req_config_modifier(projected_config, request) @@ -405,7 +412,7 @@ def add_routes( path: str = "", input_type: Union[Type, Literal["auto"], BaseModel] = "auto", output_type: Union[Type, Literal["auto"], BaseModel] = "auto", - config_keys: Sequence[str] = (), + config_keys: Sequence[str] = ("configurable",), include_callback_events: bool = False, enable_feedback_endpoint: bool = False, per_req_config_modifier: Optional[PerRequestConfigModifier] = None, @@ -438,7 +445,8 @@ def add_routes( Favor using runnable.with_types(input_type=..., output_type=...) instead. This parameter may get deprecated! config_keys: list of config keys that will be accepted, by default - no config keys are accepted. + will accept `configurable` key in the config. Will only be used + if the runnable is configurable. include_callback_events: Whether to include callback events in the response. If true, the client will be able to show trace information including events that occurred on the server side. diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 8f4281f4..3b943698 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -1084,7 +1084,7 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: assert chain.invoke({"name": "cat"}) == "say cat" app = FastAPI() - add_routes(app, chain, config_keys=["tags", "configurable"]) + add_routes(app, chain) async with get_async_remote_runnable(app) as remote_runnable: # Test with hard-coded LLM @@ -1094,7 +1094,7 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: assert ( await remote_runnable.ainvoke( {"name": "foo"}, - {"configurable": {"template": "hear {name}"}, "tags": ["h"]}, + {"configurable": {"template": "hear {name}"}}, ) == "hear foo" ) @@ -1102,11 +1102,22 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: assert ( await remote_runnable.ainvoke( {"name": "foo"}, - {"configurable": {"llm": "hardcoded_llm"}, "tags": ["h"]}, + {"configurable": {"llm": "hardcoded_llm"}}, ) == "hello Mr. Kitten!" ) + add_routes(app, chain, path="/no_config", config_keys=["tags"]) + + async with get_async_remote_runnable(app, path="/no_config") as remote_runnable: + with pytest.raises(httpx.HTTPError) as cb: + await remote_runnable.ainvoke( + {"name": "foo"}, + {"configurable": {"template": "hear {name}"}}, + ) + + assert cb.value.response.status_code == 422 + # Test for utilities