Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more run time warnings #753

Merged
merged 7 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ async def _get_config_and_input(
# This takes into account changes in the input type when
# using configuration.
schema = self._runnable.with_config(config).input_schema
input_ = schema.validate(body.input)
input_ = schema.model_validate(body.input)
return config, _unpack_input(input_)
except ValidationError as e:
raise RequestValidationError(e.errors(), body=body)
Expand Down Expand Up @@ -892,7 +892,7 @@ async def batch(
raise RequestValidationError(errors=["Invalid JSON body"])

with _with_validation_error_translation():
body = BatchRequestShallowValidator.validate(body)
body = BatchRequestShallowValidator.model_validate(body)
config = body.config

# First unpack the config
Expand Down Expand Up @@ -943,7 +943,7 @@ async def batch(

inputs = [
_unpack_input(
self._runnable.with_config(config_).input_schema.validate(input_)
self._runnable.with_config(config_).input_schema.model_validate(input_)
)
for config_, input_ in zip(configs_, inputs_)
]
Expand Down
73 changes: 42 additions & 31 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
The main entry point is the `add_routes` function which adds the routes to an existing
FastAPI app or APIRouter.
"""
import warnings
import weakref
from typing import (
Any,
Expand Down Expand Up @@ -201,37 +202,47 @@ def _register_path_for_app(
def _setup_global_app_handlers(
app: Union[FastAPI, APIRouter], endpoint_configuration: _EndpointConfiguration
) -> None:
@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501

def green(text: str) -> str:
"""Return the given text in green."""
return "\x1b[1;32;40m" + text + "\x1b[0m"

def orange(text: str) -> str:
"""Return the given text in orange."""
return "\x1b[1;31;40m" + text + "\x1b[0m"

paths = _APP_TO_PATHS[app]
print(LANGSERVE)
for path in paths:
if endpoint_configuration.is_playground_enabled:
print(
f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" is '
f"live at:"
)
print(f'{green("LANGSERVE:")} │')
print(f'{green("LANGSERVE:")} └──> {path}/playground/')
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/')
with warnings.catch_warnings():
# We are using deprecated functionality here.
# This code should be re-written to simply construct a pydantic model
# using inspect.signature and create_model.
warnings.filterwarnings(
"ignore",
"[\\s.]*on_event is deprecated[\\s.]*",
category=DeprecationWarning,
)

@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501

def green(text: str) -> str:
"""Return the given text in green."""
return "\x1b[1;32;40m" + text + "\x1b[0m"

def orange(text: str) -> str:
"""Return the given text in orange."""
return "\x1b[1;31;40m" + text + "\x1b[0m"

paths = _APP_TO_PATHS[app]
print(LANGSERVE)
for path in paths:
if endpoint_configuration.is_playground_enabled:
print(
f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" '
f'is live at:'
)
print(f'{green("LANGSERVE:")} │')
print(f'{green("LANGSERVE:")} └──> {path}/playground/')
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/')


# PUBLIC API
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv"
# take more than 5 seconds
timeout = 5
asyncio_mode = "auto"
filterwarnings = [
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
]

43 changes: 19 additions & 24 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import sys
import uuid
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -123,7 +122,7 @@ def _replace_run_id_in_stream_resp(streamed_resp: str) -> str:
return streamed_resp.replace(uuid, "<REPLACED>")


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop()
Expand All @@ -134,7 +133,7 @@ def event_loop():


@pytest.fixture()
def app(event_loop: AbstractEventLoop) -> FastAPI:
def app() -> FastAPI:
"""A simple server that wraps a Runnable and exposes it as an API."""

async def add_one_or_passthrough(
Expand All @@ -158,7 +157,7 @@ async def add_one_or_passthrough(


@pytest.fixture()
def app_for_config(event_loop: AbstractEventLoop) -> FastAPI:
def app_for_config() -> FastAPI:
"""A simple server that wraps a Runnable and exposes it as an API."""

async def return_config(
Expand Down Expand Up @@ -223,7 +222,7 @@ async def get_async_test_client(
app=server,
raise_app_exceptions=raise_app_exceptions,
)
async_client = AsyncClient(app=server, base_url=url, transport=transport)
async_client = AsyncClient(base_url=url, transport=transport)
try:
yield async_client
finally:
Expand Down Expand Up @@ -333,7 +332,7 @@ async def test_server_async(app: FastAPI) -> None:
# test bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
# Test invoke
response = await async_client.post("/invoke", data="bad json []")
response = await async_client.post("/invoke", content="bad json []")
# Client side error bad json.
assert response.status_code == 422

Expand All @@ -353,7 +352,7 @@ async def test_server_async(app: FastAPI) -> None:
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
# Test invoke
# Test bad batch requests
response = await async_client.post("/batch", data="bad json []")
response = await async_client.post("/batch", content="bad json []")
# Client side error bad json.
assert response.status_code == 422

Expand All @@ -378,15 +377,15 @@ async def test_server_async(app: FastAPI) -> None:
# test stream bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
# Test bad stream requests
response = await async_client.post("/stream", data="bad json []")
response = await async_client.post("/stream", content="bad json []")
assert response.status_code == 422

response = await async_client.post("/stream", json={})
assert response.status_code == 422

# test stream_log bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
response = await async_client.post("/stream_log", data="bad json []")
response = await async_client.post("/stream_log", content="bad json []")
assert response.status_code == 422

response = await async_client.post("/stream_log", json={})
Expand Down Expand Up @@ -448,7 +447,7 @@ async def test_server_astream_events(app: FastAPI) -> None:

# test stream_events with bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
response = await async_client.post("/stream_events", data="bad json []")
response = await async_client.post("/stream_events", content="bad json []")
assert response.status_code == 422

response = await async_client.post("/stream_events", json={})
Expand Down Expand Up @@ -854,7 +853,7 @@ async def with_errors(inputs: dict) -> AsyncIterator[int]:
assert e.value.response.status_code == 500


async def test_astream_log_allowlist(event_loop: AbstractEventLoop) -> None:
async def test_astream_log_allowlist() -> None:
"""Test async stream with an allowlist."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1035,7 +1034,7 @@ async def test_invoke_as_part_of_sequence_async(
}


async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
async def test_multiple_runnables() -> None:
"""Test serving multiple runnables."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1159,7 +1158,7 @@ async def add_one(x: int) -> int:
await runnable.abatch(["hello"])


async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
async def test_input_validation_with_lc_types() -> None:
"""Test client side and server side exceptions."""

app = FastAPI()
Expand Down Expand Up @@ -1252,9 +1251,7 @@ async def test_async_client_close() -> None:
assert async_client.is_closed is True


async def test_openapi_docs_with_identical_runnables(
event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
async def test_openapi_docs_with_identical_runnables(mocker: MockerFixture) -> None:
"""Test client side and server side exceptions."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1301,7 +1298,7 @@ async def add_one(x: int) -> int:
assert response.status_code == 200


async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None:
async def test_configurable_runnables() -> None:
"""Add tests for using langchain's configurable runnables"""

template = PromptTemplate.from_template("say {name}").configurable_fields(
Expand Down Expand Up @@ -1391,7 +1388,7 @@ class Foo(BaseModel):
assert Model.__name__ == "BarFoo"


async def test_input_config_output_schemas(event_loop: AbstractEventLoop) -> None:
async def test_input_config_output_schemas() -> None:
"""Test schemas returned for different configurations."""
# TODO(Fix me): need to fix handling of global state -- we get problems
# gives inconsistent results when running multiple tests / results
Expand Down Expand Up @@ -1753,7 +1750,7 @@ async def test_server_side_error() -> None:
# assert e.response.text == "Internal Server Error"


def test_server_side_error_sync(event_loop: AbstractEventLoop) -> None:
def test_server_side_error_sync() -> None:
"""Test server side error handling."""

app = FastAPI()
Expand Down Expand Up @@ -1982,7 +1979,7 @@ async def test_enforce_trailing_slash_in_client() -> None:
assert r.url == "nosuchurl/"


async def test_per_request_config_modifier(event_loop: AbstractEventLoop) -> None:
async def test_per_request_config_modifier() -> None:
"""Test updating the config based on the raw request object."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -2025,9 +2022,7 @@ async def header_passthru_modifier(
assert response.json()["output"] == 2


async def test_per_request_config_modifier_endpoints(
event_loop: AbstractEventLoop,
) -> None:
async def test_per_request_config_modifier_endpoints() -> None:
"""Verify that per request modifier is only applied for the expected endpoints."""

# this test verifies that per request modifier is only
Expand Down Expand Up @@ -2097,7 +2092,7 @@ async def buggy_modifier(
assert response.status_code != 500


async def test_uuid_serialization(event_loop: AbstractEventLoop) -> None:
async def test_uuid_serialization() -> None:
"""Test updating the config based on the raw request object."""
import datetime

Expand Down
Loading