Skip to content

Commit

Permalink
Add filtering for LangServe by run name through allow list (#291)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
akira and eyurtsev authored Dec 6, 2023
1 parent 2ae30fb commit e9917a3
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 19 deletions.
33 changes: 20 additions & 13 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(
include_callback_events: bool = False,
enable_feedback_endpoint: bool = False,
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
) -> None:
"""Create a new RunnableServer.
Expand Down Expand Up @@ -444,6 +445,7 @@ def __init__(
self.base_url = base_url
self.well_known_lc_serializer = WellKnownLCSerializer()
self.enable_feedback_endpoint = enable_feedback_endpoint
self.stream_log_name_allow_list = stream_log_name_allow_list

# Please do not change the naming on ls_client. It is used with mocking
# in our unit tests for langsmith integrations.
Expand Down Expand Up @@ -858,20 +860,25 @@ async def _stream_log() -> AsyncIterator[dict]:
raise AssertionError(
f"Expected a RunLog instance got {type(chunk)}"
)
data = {
"ops": chunk.ops,
}
if (
self.stream_log_name_allow_list is None
or self.runnable.config.get("run_name")
in self.stream_log_name_allow_list
):
data = {
"ops": chunk.ops,
}

# Temporary adapter
yield {
# EventSourceResponse expects a string for data
# so after serializing into bytes, we decode into utf-8
# to get a string.
"data": self.well_known_lc_serializer.dumps(data).decode(
"utf-8"
),
"event": "data",
}
# Temporary adapter
yield {
# EventSourceResponse expects a string for data
# so after serializing into bytes, we decode into utf-8
# to get a string.
"data": self.well_known_lc_serializer.dumps(data).decode(
"utf-8"
),
"event": "data",
}
yield {"event": "end"}
except BaseException:
yield {
Expand Down
16 changes: 10 additions & 6 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def _setup_global_app_handlers(app: Union[FastAPI, APIRouter]) -> None:
@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501

Expand Down Expand Up @@ -218,8 +218,9 @@ def add_routes(
include_callback_events: bool = False,
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
enable_feedback_endpoint: bool = False,
enabled_endpoints: Optional[Sequence[EndpointName]] = None,
disabled_endpoints: Optional[Sequence[EndpointName]] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
enabled_endpoints: Optional[Sequence[EndpointName]] = None,
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
Expand Down Expand Up @@ -308,6 +309,9 @@ def add_routes(
disabled_endpoints=["playground"],
)
```
stream_log_name_allow_list: list of run names that the client can
stream as intermediate steps
"""
endpoint_configuration = _EndpointConfiguration(
enabled_endpoints=enabled_endpoints,
Expand Down Expand Up @@ -348,8 +352,8 @@ def add_routes(
include_callback_events=include_callback_events,
enable_feedback_endpoint=enable_feedback_endpoint,
per_req_config_modifier=per_req_config_modifier,
stream_log_name_allow_list=stream_log_name_allow_list,
)

namespace = path or ""

route_tags = [path.strip("/")] if path else None
Expand Down
84 changes: 84 additions & 0 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,90 @@ def add_one(x: int) -> int:
}


@pytest.mark.asyncio
async def test_astream_log_allowlist(event_loop: AbstractEventLoop) -> None:
"""Test async stream with an allowlist."""

async def add_one(x: int) -> int:
"""Add one to simulate a valid function"""
return x + 1

app = FastAPI()
add_routes(
app,
RunnableLambda(add_one).with_config({"run_name": "allowed"}),
path="/empty_allowlist",
input_type=int,
stream_log_name_allow_list=[],
)
add_routes(
app,
RunnableLambda(add_one).with_config({"run_name": "allowed"}),
input_type=int,
path="/allowlist",
stream_log_name_allow_list=["allowed"],
)

# Invoke request
async with get_async_remote_runnable(app, path="/empty_allowlist/") as runnable:
run_log_patches = []

async for chunk in runnable.astream_log(1):
run_log_patches.append(chunk)

assert len(run_log_patches) == 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_tags=[]):
run_log_patches.append(chunk)

assert len(run_log_patches) == 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_types=[]):
run_log_patches.append(chunk)

assert len(run_log_patches) == 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_names=[]):
run_log_patches.append(chunk)

assert len(run_log_patches) == 0

async with get_async_remote_runnable(app, path="/allowlist/") as runnable:
run_log_patches = []

async for chunk in runnable.astream_log(1):
run_log_patches.append(chunk)

assert len(run_log_patches) > 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_tags=[]):
run_log_patches.append(chunk)

assert len(run_log_patches) > 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_types=[]):
run_log_patches.append(chunk)

assert len(run_log_patches) > 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_names=[]):
run_log_patches.append(chunk)

assert len(run_log_patches) > 0

run_log_patches = []
async for chunk in runnable.astream_log(1, include_names=["allowed"]):
run_log_patches.append(chunk)

assert len(run_log_patches) > 0


def test_invoke_as_part_of_sequence(sync_remote_runnable: RemoteRunnable) -> None:
"""Test as part of sequence."""
runnable = sync_remote_runnable | RunnableLambda(func=lambda x: x + 1)
Expand Down

0 comments on commit e9917a3

Please sign in to comment.