Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth for playground #539

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/auth/global_deps/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
* https://fastapi.tiangolo.com/tutorial/security/
"""

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import APIKeyHeader
from langchain_core.runnables import RunnableLambda
from typing_extensions import Annotated

from langserve import add_routes

XToken = APIKeyHeader(name="x-token")

async def verify_token(x_token: Annotated[str, Header()]) -> None:

async def verify_token(x_token: str = Depends(XToken)) -> None:
"""Verify the token is valid."""
# Replace this with your actual authentication logic
if x_token != "secret-token":
Expand Down
8 changes: 5 additions & 3 deletions examples/auth/path_dependencies/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
* https://fastapi.tiangolo.com/tutorial/security/
""" # noqa: E501

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import APIKeyHeader
from langchain_core.runnables import RunnableLambda
from typing_extensions import Annotated

from langserve import add_routes

XToken = APIKeyHeader(name="x-token")

async def verify_token(x_token: Annotated[str, Header()]) -> None:

async def verify_token(x_token: str = Depends(XToken)) -> None:
"""Verify the token is valid."""
# Replace this with your actual authentication logic
if x_token != "secret-token":
Expand Down
7 changes: 5 additions & 2 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
from langserve.lzstring import LZString
from langserve.playground import serve_playground
from langserve.playground import PlaygroundConfig, serve_playground
from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model
from langserve.schema import (
BatchResponseMetadata,
Expand Down Expand Up @@ -468,6 +468,7 @@ def __init__(
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
playground_config: Optional[PlaygroundConfig] = None,
) -> None:
"""Create an API handler for the given runnable.

Expand Down Expand Up @@ -561,6 +562,7 @@ def __init__(
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
self._playground_config = playground_config

# Client is patched using mock.patch, if changing the names
# remember to make relevant updates in the unit tests.
Expand Down Expand Up @@ -1366,7 +1368,8 @@ async def playground(
file_path,
feedback_enabled,
public_trace_link_enabled,
playground_type=self.playground_type,
self.playground_type,
playground_config=self._playground_config,
)

async def create_feedback(
Expand Down
28 changes: 27 additions & 1 deletion langserve/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import mimetypes
import os
from string import Template
from typing import Literal, Sequence, Type
from typing import Literal, Optional, Sequence, Type, Union

from fastapi.responses import Response
from fastapi.security import APIKeyCookie, APIKeyHeader, APIKeyQuery
from langchain.schema.runnable import Runnable
from typing_extensions import TypedDict

from langserve.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -47,6 +49,15 @@ def _get_mimetype(path: str) -> str:
return mime_type


SupportedSecurityScheme = Union[APIKeyHeader, APIKeyQuery, APIKeyCookie]


class PlaygroundConfig(TypedDict, total=False):
"""Configuration for the playground."""

security_scheme: Optional[SupportedSecurityScheme]


async def serve_playground(
runnable: Runnable,
input_schema: Type[BaseModel],
Expand All @@ -56,8 +67,20 @@ async def serve_playground(
feedback_enabled: bool,
public_trace_link_enabled: bool,
playground_type: Literal["default", "chat"],
*,
playground_config: Optional[PlaygroundConfig] = None,
) -> Response:
"""Serve the playground."""
security_scheme = (
playground_config.get("security_scheme") if playground_config else None
)
if not isinstance(
security_scheme, (APIKeyHeader, APIKeyQuery, APIKeyCookie, type(None))
):
raise NotImplementedError(
"Only APIKeyHeader, APIKeyQuery, APIKeyCookie, and None are supported."
)

if playground_type == "default":
path_to_dist = "./playground/dist"
elif playground_type == "chat":
Expand Down Expand Up @@ -98,6 +121,9 @@ async def serve_playground(
LANGSERVE_PUBLIC_TRACE_LINK_ENABLED=json.dumps(
"true" if public_trace_link_enabled else "false"
),
SECURITY_SCHEME=security_scheme.model.json()
if security_scheme
else json.dumps({}),
)
else:
response = f.buffer.read()
Expand Down
Loading