From ea7bb5d9707bdadc8bebd8ffc7a352cb88cd1118 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 13 Nov 2023 13:58:02 -0500 Subject: [PATCH 1/3] Minor: Update names to not include `Test` prefix (#222) This triggers a warning from pytest collection when running tests since the namespaces begin with `Test` --- tests/unit_tests/test_serialization.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit_tests/test_serialization.py b/tests/unit_tests/test_serialization.py index 4ae4ab04..4a77ce13 100644 --- a/tests/unit_tests/test_serialization.py +++ b/tests/unit_tests/test_serialization.py @@ -123,14 +123,14 @@ def test_decode_events(data: Any, expected: Any) -> None: assert load_events(data) == expected -class TestEnum(Enum): +class SimpleEnum(Enum): a = "a" b = "b" -class TestModel(BaseModel): +class SimpleModel(BaseModel): x: int - y: TestEnum + y: SimpleEnum z: uuid.UUID dt: datetime.datetime d: datetime.date @@ -148,11 +148,11 @@ class TestModel(BaseModel): (datetime.date(2020, 1, 1), "2020-01-01"), (datetime.time(0, 0, 0), "00:00:00"), (datetime.time(0, 0, 0, 1), "00:00:00.000001"), - (TestEnum.a, "a"), + (SimpleEnum.a, "a"), ( - TestModel( + SimpleModel( x=1, - y=TestEnum.a, + y=SimpleEnum.a, z=uuid.UUID(int=1), dt=datetime.datetime(2020, 1, 1), d=datetime.date(2020, 1, 1), From d591be4dce068ea4a87f4dfc052f1721853bd428 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 13 Nov 2023 14:26:38 -0500 Subject: [PATCH 2/3] Enable configurable by default langserve side (#219) * Enable configurable key by default * Should be safe to enable by default since the user is already marking the runnable as configurable (when they make it configurable) * Raise an client side error if the client specifies a `configurable` key but the server does not accept it (will help in debugging) --- langserve/server.py | 31 +++++++++++++++++--------- tests/unit_tests/test_server_client.py | 17 +++++++++++--- tests/unit_tests/test_validation.py | 4 ++-- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/langserve/server.py b/langserve/server.py index aec218fb..9bd37fce 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -97,7 +97,7 @@ def _config_from_hash(config_hash: str) -> Dict[str, Any]: def _unpack_request_config( *configs: Union[BaseModel, Mapping, str], - keys: Sequence[str], + config_keys: Sequence[str], model: Type[BaseModel], request: Request, per_req_config_modifier: Optional[PerRequestConfigModifier], @@ -114,7 +114,15 @@ def _unpack_request_config( else: raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}") config = merge_configs(*config_dicts) - projected_config = {k: config[k] for k in keys if k in config} + if "configurable" in config and config["configurable"]: + if "configurable" not in config_keys: + raise HTTPException( + 422, + "The config field `configurable` has been disallowed by the server. " + "This can be modified server side by adding `configurable` to the list " + "of `config_keys` argument in `add_routes`", + ) + projected_config = {k: config[k] for k in config_keys if k in config} return ( per_req_config_modifier(projected_config, request) if per_req_config_modifier @@ -405,7 +413,7 @@ def add_routes( path: str = "", input_type: Union[Type, Literal["auto"], BaseModel] = "auto", output_type: Union[Type, Literal["auto"], BaseModel] = "auto", - config_keys: Sequence[str] = (), + config_keys: Sequence[str] = ("configurable",), include_callback_events: bool = False, enable_feedback_endpoint: bool = False, per_req_config_modifier: Optional[PerRequestConfigModifier] = None, @@ -438,7 +446,8 @@ def add_routes( Favor using runnable.with_types(input_type=..., output_type=...) instead. This parameter may get deprecated! config_keys: list of config keys that will be accepted, by default - no config keys are accepted. + will accept `configurable` key in the config. Will only be used + if the runnable is configurable. include_callback_events: Whether to include callback events in the response. If true, the client will be able to show trace information including events that occurred on the server side. @@ -569,7 +578,7 @@ async def _get_config_and_input( config = _unpack_request_config( config_hash, body.config, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, @@ -654,7 +663,7 @@ async def batch( _unpack_request_config( config_hash, config, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, @@ -665,7 +674,7 @@ async def batch( configs = _unpack_request_config( config_hash, config, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, @@ -935,7 +944,7 @@ async def input_schema(request: Request, config_hash: str = "") -> Any: with _with_validation_error_translation(): config = _unpack_request_config( config_hash, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, @@ -958,7 +967,7 @@ async def output_schema(request: Request, config_hash: str = "") -> Any: with _with_validation_error_translation(): config = _unpack_request_config( config_hash, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, @@ -978,7 +987,7 @@ async def config_schema(request: Request, config_hash: str = "") -> Any: with _with_validation_error_translation(): config = _unpack_request_config( config_hash, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, @@ -997,7 +1006,7 @@ async def playground( with _with_validation_error_translation(): config = _unpack_request_config( config_hash, - keys=config_keys, + config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 8f4281f4..3b943698 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -1084,7 +1084,7 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: assert chain.invoke({"name": "cat"}) == "say cat" app = FastAPI() - add_routes(app, chain, config_keys=["tags", "configurable"]) + add_routes(app, chain) async with get_async_remote_runnable(app) as remote_runnable: # Test with hard-coded LLM @@ -1094,7 +1094,7 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: assert ( await remote_runnable.ainvoke( {"name": "foo"}, - {"configurable": {"template": "hear {name}"}, "tags": ["h"]}, + {"configurable": {"template": "hear {name}"}}, ) == "hear foo" ) @@ -1102,11 +1102,22 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: assert ( await remote_runnable.ainvoke( {"name": "foo"}, - {"configurable": {"llm": "hardcoded_llm"}, "tags": ["h"]}, + {"configurable": {"llm": "hardcoded_llm"}}, ) == "hello Mr. Kitten!" ) + add_routes(app, chain, path="/no_config", config_keys=["tags"]) + + async with get_async_remote_runnable(app, path="/no_config") as remote_runnable: + with pytest.raises(httpx.HTTPError) as cb: + await remote_runnable.ainvoke( + {"name": "foo"}, + {"configurable": {"template": "hear {name}"}}, + ) + + assert cb.value.response.status_code == 422 + # Test for utilities diff --git a/tests/unit_tests/test_validation.py b/tests/unit_tests/test_validation.py index d63f3810..c6442e6a 100644 --- a/tests/unit_tests/test_validation.py +++ b/tests/unit_tests/test_validation.py @@ -157,7 +157,7 @@ def test_invoke_request_with_runnables() -> None: Model( input={"name": "bob"}, ).config, - keys=[], + config_keys=[], model=config, request=MagicMock(Request), per_req_config_modifier=lambda x, y: x, @@ -184,7 +184,7 @@ def test_invoke_request_with_runnables() -> None: assert _unpack_request_config( request.config, - keys=["configurable"], + config_keys=["configurable"], model=config, request=MagicMock(Request), per_req_config_modifier=lambda x, y: x, From 38b935644fce6807cc5443f8b1e9f8eed124eef0 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 13 Nov 2023 14:26:47 -0500 Subject: [PATCH 3/3] Apply noqa to specific statement rather than entire file (#223) Apply noqa to specific statement rather than entire file --- langserve/server.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/langserve/server.py b/langserve/server.py index 9bd37fce..d147f010 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -275,7 +275,6 @@ def _scrub_exceptions_in_event(event: CallbackEventDict) -> CallbackEventDict: def _setup_global_app_handlers(app: Union[FastAPI, APIRouter]) -> None: @app.on_event("startup") async def startup_event(): - # ruff: noqa: E501 LANGSERVE = """ __ ___ .__ __. _______ _______. _______ .______ ____ ____ _______ | | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____| @@ -283,7 +282,7 @@ async def startup_event(): | | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __| | `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____ |_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______| -""" +""" # noqa: E501 def green(text): return "\x1b[1;32;40m" + text + "\x1b[0m" @@ -292,7 +291,8 @@ def green(text): print(LANGSERVE) for path in paths: print( - f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" is live at:' + f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" is ' + f'live at:' ) print(f'{green("LANGSERVE:")} │') print(f'{green("LANGSERVE:")} └──> {path}/playground/') @@ -622,8 +622,8 @@ async def invoke( return _json_encode_response( InvokeResponse( output=well_known_lc_serializer.dumpd(output), - # Callbacks are scrubbed and exceptions are converted to serializable format - # before returned in the response. + # Callbacks are scrubbed and exceptions are converted to + # serializable format before returned in the response. callback_events=callback_events, metadata=SingletonResponseMetadata( run_id=_get_base_run_id_as_str(event_aggregator)