Skip to content

Commit

Permalink
Only run per request modifier for invoke,batch,stream,astream_log (#341)
Browse files Browse the repository at this point in the history
This only runs the per request modifier for the endpoints where one
would expect it to be run.

If there is another use case, please file an issue explaining it and
we'll accommodate it.

#325
  • Loading branch information
eyurtsev authored Dec 20, 2023
1 parent 08b12b3 commit b3e728d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 7 deletions.
28 changes: 24 additions & 4 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,12 @@ async def input_schema(
config_keys=self._config_keys,
model=self._ConfigPayload,
request=request,
per_req_config_modifier=self._per_req_config_modifier,
# Do not use per request config modifier for output schema
# since it's unclear why it would make sense to modify
# this using a per request config modifier.
# If this is needed, for some reason please file an issue explaining
# the user case.
per_req_config_modifier=None,
server_config=server_config,
)
config = _update_config_with_defaults(
Expand All @@ -1101,7 +1106,12 @@ async def output_schema(
config_keys=self._config_keys,
model=self._ConfigPayload,
request=request,
per_req_config_modifier=self._per_req_config_modifier,
# Do not use per request config modifier for output schema
# since it's unclear why it would make sense to modify
# this using a per request config modifier.
# If this is needed, for some reason please file an issue explaining
# the user case.
per_req_config_modifier=None,
server_config=server_config,
)
config = _update_config_with_defaults(
Expand All @@ -1123,7 +1133,12 @@ async def config_schema(
config_keys=self._config_keys,
model=self._ConfigPayload,
request=request,
per_req_config_modifier=self._per_req_config_modifier,
# Do not use per request config modifier for output schema
# since it's unclear why it would make sense to modify
# this using a per request config modifier.
# If this is needed, for some reason please file an issue explaining
# the user case.
per_req_config_modifier=None,
server_config=server_config,
)
config = _update_config_with_defaults(
Expand All @@ -1150,7 +1165,12 @@ async def playground(
config_keys=self._config_keys,
model=self._ConfigPayload,
request=request,
per_req_config_modifier=self._per_req_config_modifier,
# Do not use per request config modifier for output schema
# since it's unclear why it would make sense to modify
# this using a per request config modifier.
# If this is needed, for some reason please file an issue explaining
# the user case.
per_req_config_modifier=None,
server_config=server_config,
)

Expand Down
3 changes: 3 additions & 0 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def add_routes(
for example, if the user wants to pass in a header containing credentials
to a runnable. The RunnableConfig is presented in its dictionary form.
Note that only keys in `config_keys` will be modifiable by this function.
As of 0.0.37, this function is only called for the invoke, batch, stream,
and stream_log endpoints. This function is not called for the playground,
input_schema, output_schema, and config_schema endpoints etc.
enable_feedback_endpoint: Whether to enable an endpoint for logging feedback
to LangSmith. Enabled by default. If this flag is disabled or LangSmith
tracing is not enabled for the runnable, then 400 errors will be thrown
Expand Down
55 changes: 52 additions & 3 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,9 +1778,7 @@ async def test_enforce_trailing_slash_in_client() -> None:
assert r.url == "nosuchurl/"


async def test_per_request_config_modifier(
event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
async def test_per_request_config_modifier(event_loop: AbstractEventLoop) -> None:
"""Test updating the config based on the raw request object."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1810,6 +1808,57 @@ def header_passthru_modifier(
per_req_config_modifier=header_passthru_modifier,
)

# this test verifies that per request modifier is only
# applied for the expected endpoints
def buggy_modifier(config: Dict[str, Any], request: Request) -> Dict[str, Any]:
"""Update the config"""
raise ValueError("oops I did it again")

add_routes(
app,
server_runnable,
path="/with_buggy_modifier",
per_req_config_modifier=buggy_modifier,
)

async with get_async_test_client(
app,
raise_app_exceptions=False,
) as async_client:
endpoints_to_test = (
"invoke",
"batch",
"stream",
"stream_log",
"input_schema",
"output_schema",
"config_schema",
"playground/index.html",
)

for endpoint in endpoints_to_test:
url = "/with_buggy_modifier/" + endpoint

if endpoint == "batch":
payload = {"inputs": [1, 2]}
response = await async_client.post(url, json=payload)
elif endpoint in {"invoke", "stream", "stream_log"}:
payload = {"input": 1}
response = await async_client.post(url, json=payload)
elif endpoint in {"input_schema", "output_schema", "config_schema"}:
response = await async_client.get(url)
elif endpoint == "playground/index.html":
response = await async_client.get(url)
else:
raise ValueError(f"Unknown endpoint {endpoint}")

if endpoint in {"invoke", "batch"}:
assert response.status_code == 500
elif endpoint in {"stream", "stream_log"}:
assert '"status_code": 500' in response.text
else:
assert response.status_code != 500


async def test_uuid_serialization(event_loop: AbstractEventLoop) -> None:
"""Test updating the config based on the raw request object."""
Expand Down

0 comments on commit b3e728d

Please sign in to comment.