From 588e7975c2a47f85142a7a787fd34f03d74c5430 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 14 Mar 2024 13:19:37 -0400 Subject: [PATCH 1/3] x --- langserve/api_handler.py | 7 +++++-- langserve/playground.py | 28 +++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 20576b11..c6ddca40 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -36,7 +36,7 @@ from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict from langserve.lzstring import LZString -from langserve.playground import serve_playground +from langserve.playground import PlaygroundConfig, serve_playground from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model from langserve.schema import ( BatchResponseMetadata, @@ -468,6 +468,7 @@ def __init__( per_req_config_modifier: Optional[PerRequestConfigModifier] = None, stream_log_name_allow_list: Optional[Sequence[str]] = None, playground_type: Literal["default", "chat"] = "default", + playground_config: Optional[PlaygroundConfig] = None, ) -> None: """Create an API handler for the given runnable. @@ -561,6 +562,7 @@ def __init__( self._enable_feedback_endpoint = enable_feedback_endpoint self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint self._names_in_stream_allow_list = stream_log_name_allow_list + self._playground_config = playground_config # Client is patched using mock.patch, if changing the names # remember to make relevant updates in the unit tests. @@ -1366,7 +1368,8 @@ async def playground( file_path, feedback_enabled, public_trace_link_enabled, - playground_type=self.playground_type, + self.playground_type, + playground_config=self._playground_config, ) async def create_feedback( diff --git a/langserve/playground.py b/langserve/playground.py index eddff09a..6030b2d3 100644 --- a/langserve/playground.py +++ b/langserve/playground.py @@ -2,10 +2,12 @@ import mimetypes import os from string import Template -from typing import Literal, Sequence, Type +from typing import Literal, Optional, Sequence, Type, Union from fastapi.responses import Response +from fastapi.security import APIKeyCookie, APIKeyHeader, APIKeyQuery from langchain.schema.runnable import Runnable +from typing_extensions import TypedDict from langserve.pydantic_v1 import BaseModel @@ -47,6 +49,15 @@ def _get_mimetype(path: str) -> str: return mime_type +SupportedSecurityScheme = Union[APIKeyHeader, APIKeyQuery, APIKeyCookie] + + +class PlaygroundConfig(TypedDict, total=False): + """Configuration for the playground.""" + + security_scheme: Optional[SupportedSecurityScheme] + + async def serve_playground( runnable: Runnable, input_schema: Type[BaseModel], @@ -56,8 +67,20 @@ async def serve_playground( feedback_enabled: bool, public_trace_link_enabled: bool, playground_type: Literal["default", "chat"], + *, + playground_config: Optional[PlaygroundConfig] = None, ) -> Response: """Serve the playground.""" + security_scheme = ( + playground_config.get("security_scheme") if playground_config else None + ) + if not isinstance( + security_scheme, (APIKeyHeader, APIKeyQuery, APIKeyCookie, type(None)) + ): + raise NotImplementedError( + "Only APIKeyHeader, APIKeyQuery, APIKeyCookie, and None are supported." + ) + if playground_type == "default": path_to_dist = "./playground/dist" elif playground_type == "chat": @@ -98,6 +121,9 @@ async def serve_playground( LANGSERVE_PUBLIC_TRACE_LINK_ENABLED=json.dumps( "true" if public_trace_link_enabled else "false" ), + SECURITY_SCHEME=security_scheme.model.json() + if security_scheme + else json.dumps({}), ) else: response = f.buffer.read() From 1ad5619598d3b64a8306bf682395c28847a4c1b1 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 14 Mar 2024 13:30:15 -0400 Subject: [PATCH 2/3] x --- examples/auth/global_deps/server.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/auth/global_deps/server.py b/examples/auth/global_deps/server.py index f9379e2e..86a3ebb1 100755 --- a/examples/auth/global_deps/server.py +++ b/examples/auth/global_deps/server.py @@ -13,14 +13,16 @@ * https://fastapi.tiangolo.com/tutorial/security/ """ -from fastapi import Depends, FastAPI, Header, HTTPException +from fastapi import Depends, FastAPI, HTTPException +from fastapi.security import APIKeyHeader from langchain_core.runnables import RunnableLambda -from typing_extensions import Annotated from langserve import add_routes +XToken = APIKeyHeader(name="x-token") -async def verify_token(x_token: Annotated[str, Header()]) -> None: + +async def verify_token(x_token: str = Depends(XToken)) -> None: """Verify the token is valid.""" # Replace this with your actual authentication logic if x_token != "secret-token": From 1bdcb70e73f947cb9ff7ce314fa3ccfaab0aec84 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 14 Mar 2024 13:30:28 -0400 Subject: [PATCH 3/3] x --- examples/auth/path_dependencies/server.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/auth/path_dependencies/server.py b/examples/auth/path_dependencies/server.py index 8e3f8fc0..81e14edf 100755 --- a/examples/auth/path_dependencies/server.py +++ b/examples/auth/path_dependencies/server.py @@ -13,14 +13,16 @@ * https://fastapi.tiangolo.com/tutorial/security/ """ # noqa: E501 -from fastapi import Depends, FastAPI, Header, HTTPException +from fastapi import Depends, FastAPI, HTTPException +from fastapi.security import APIKeyHeader from langchain_core.runnables import RunnableLambda -from typing_extensions import Annotated from langserve import add_routes +XToken = APIKeyHeader(name="x-token") -async def verify_token(x_token: Annotated[str, Header()]) -> None: + +async def verify_token(x_token: str = Depends(XToken)) -> None: """Verify the token is valid.""" # Replace this with your actual authentication logic if x_token != "secret-token":