From e9917a35058ee099f0cc85fe488bc215a9713455 Mon Sep 17 00:00:00 2001 From: Alex Kira Date: Wed, 6 Dec 2023 10:59:57 -0800 Subject: [PATCH] Add filtering for LangServe by run name through allow list (#291) Co-authored-by: Eugene Yurtsev --- langserve/api_handler.py | 33 ++++++---- langserve/server.py | 16 +++-- tests/unit_tests/test_server_client.py | 84 ++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 19 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index db560699..ae16caed 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -389,6 +389,7 @@ def __init__( include_callback_events: bool = False, enable_feedback_endpoint: bool = False, per_req_config_modifier: Optional[PerRequestConfigModifier] = None, + stream_log_name_allow_list: Optional[Sequence[str]] = None, ) -> None: """Create a new RunnableServer. @@ -444,6 +445,7 @@ def __init__( self.base_url = base_url self.well_known_lc_serializer = WellKnownLCSerializer() self.enable_feedback_endpoint = enable_feedback_endpoint + self.stream_log_name_allow_list = stream_log_name_allow_list # Please do not change the naming on ls_client. It is used with mocking # in our unit tests for langsmith integrations. @@ -858,20 +860,25 @@ async def _stream_log() -> AsyncIterator[dict]: raise AssertionError( f"Expected a RunLog instance got {type(chunk)}" ) - data = { - "ops": chunk.ops, - } + if ( + self.stream_log_name_allow_list is None + or self.runnable.config.get("run_name") + in self.stream_log_name_allow_list + ): + data = { + "ops": chunk.ops, + } - # Temporary adapter - yield { - # EventSourceResponse expects a string for data - # so after serializing into bytes, we decode into utf-8 - # to get a string. - "data": self.well_known_lc_serializer.dumps(data).decode( - "utf-8" - ), - "event": "data", - } + # Temporary adapter + yield { + # EventSourceResponse expects a string for data + # so after serializing into bytes, we decode into utf-8 + # to get a string. + "data": self.well_known_lc_serializer.dumps(data).decode( + "utf-8" + ), + "event": "data", + } yield {"event": "end"} except BaseException: yield { diff --git a/langserve/server.py b/langserve/server.py index 8adce369..bcf61fe3 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -46,11 +46,11 @@ def _setup_global_app_handlers(app: Union[FastAPI, APIRouter]) -> None: @app.on_event("startup") async def startup_event(): LANGSERVE = r""" - __ ___ .__ __. _______ _______. _______ .______ ____ ____ _______ + __ ___ .__ __. _______ _______. _______ .______ ____ ____ _______ | | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____| -| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__ -| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __| -| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____ +| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__ +| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __| +| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____ |_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______| """ # noqa: E501 @@ -218,8 +218,9 @@ def add_routes( include_callback_events: bool = False, per_req_config_modifier: Optional[PerRequestConfigModifier] = None, enable_feedback_endpoint: bool = False, - enabled_endpoints: Optional[Sequence[EndpointName]] = None, disabled_endpoints: Optional[Sequence[EndpointName]] = None, + stream_log_name_allow_list: Optional[Sequence[str]] = None, + enabled_endpoints: Optional[Sequence[EndpointName]] = None, ) -> None: """Register the routes on the given FastAPI app or APIRouter. @@ -308,6 +309,9 @@ def add_routes( disabled_endpoints=["playground"], ) ``` + stream_log_name_allow_list: list of run names that the client can + stream as intermediate steps + """ endpoint_configuration = _EndpointConfiguration( enabled_endpoints=enabled_endpoints, @@ -348,8 +352,8 @@ def add_routes( include_callback_events=include_callback_events, enable_feedback_endpoint=enable_feedback_endpoint, per_req_config_modifier=per_req_config_modifier, + stream_log_name_allow_list=stream_log_name_allow_list, ) - namespace = path or "" route_tags = [path.strip("/")] if path else None diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 3e6b5082..2bd7c23e 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -737,6 +737,90 @@ def add_one(x: int) -> int: } +@pytest.mark.asyncio +async def test_astream_log_allowlist(event_loop: AbstractEventLoop) -> None: + """Test async stream with an allowlist.""" + + async def add_one(x: int) -> int: + """Add one to simulate a valid function""" + return x + 1 + + app = FastAPI() + add_routes( + app, + RunnableLambda(add_one).with_config({"run_name": "allowed"}), + path="/empty_allowlist", + input_type=int, + stream_log_name_allow_list=[], + ) + add_routes( + app, + RunnableLambda(add_one).with_config({"run_name": "allowed"}), + input_type=int, + path="/allowlist", + stream_log_name_allow_list=["allowed"], + ) + + # Invoke request + async with get_async_remote_runnable(app, path="/empty_allowlist/") as runnable: + run_log_patches = [] + + async for chunk in runnable.astream_log(1): + run_log_patches.append(chunk) + + assert len(run_log_patches) == 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_tags=[]): + run_log_patches.append(chunk) + + assert len(run_log_patches) == 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_types=[]): + run_log_patches.append(chunk) + + assert len(run_log_patches) == 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_names=[]): + run_log_patches.append(chunk) + + assert len(run_log_patches) == 0 + + async with get_async_remote_runnable(app, path="/allowlist/") as runnable: + run_log_patches = [] + + async for chunk in runnable.astream_log(1): + run_log_patches.append(chunk) + + assert len(run_log_patches) > 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_tags=[]): + run_log_patches.append(chunk) + + assert len(run_log_patches) > 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_types=[]): + run_log_patches.append(chunk) + + assert len(run_log_patches) > 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_names=[]): + run_log_patches.append(chunk) + + assert len(run_log_patches) > 0 + + run_log_patches = [] + async for chunk in runnable.astream_log(1, include_names=["allowed"]): + run_log_patches.append(chunk) + + assert len(run_log_patches) > 0 + + def test_invoke_as_part_of_sequence(sync_remote_runnable: RemoteRunnable) -> None: """Test as part of sequence.""" runnable = sync_remote_runnable | RunnableLambda(func=lambda x: x + 1)