From 6d6ffc63db1f5dfba78aa54867ef40ff0d955fab Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 17 Nov 2023 10:34:05 -0500 Subject: [PATCH 1/2] x --- langserve/pydantic_v1.py | 9 ++++--- langserve/server.py | 52 ++++++++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/langserve/pydantic_v1.py b/langserve/pydantic_v1.py index b7cfc828..35bfd975 100644 --- a/langserve/pydantic_v1.py +++ b/langserve/pydantic_v1.py @@ -13,13 +13,16 @@ try: # F401: imported but unused - from pydantic.v1 import BaseModel, Field, ValidationError # noqa: F401 + from pydantic.v1 import BaseModel, Field, ValidationError, create_model # noqa: F401 except ImportError: - from pydantic import BaseModel, Field, ValidationError # noqa: F401 + from pydantic import BaseModel, Field, ValidationError, create_model # noqa: F401 # This is not a pydantic v1 thing, but it feels too small to create a new module for. + +PYDANTIC_VERSION = metadata.version("pydantic") + try: - _PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0]) + _PYDANTIC_MAJOR_VERSION: int = int(PYDANTIC_VERSION.split(".")[0]) except metadata.PackageNotFoundError: _PYDANTIC_MAJOR_VERSION = -1 diff --git a/langserve/server.py b/langserve/server.py index cfb2a7d7..ad8689f4 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -37,15 +37,12 @@ from langsmith.utils import tracing_is_enabled from typing_extensions import Annotated -try: - from pydantic.v1 import BaseModel, Field, ValidationError, create_model -except ImportError: - from pydantic import BaseModel, Field, ValidationError, create_model +from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict from langserve.lzstring import LZString from langserve.playground import serve_playground -from langserve.pydantic_v1 import _PYDANTIC_MAJOR_VERSION +from langserve.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, PYDANTIC_VERSION from langserve.schema import ( BatchResponseMetadata, CustomUserType, @@ -306,11 +303,15 @@ def orange(text: str) -> str: if _PYDANTIC_MAJOR_VERSION == 2: print() - print(f'{orange("OpenAPI Docs:")} ', end="") + print(f'{orange("LANGSERVE:")} ', end="") print( - "Running with pydantic >= 2: OpenAPI docs for " - "invoke/batch/stream/stream_log` endpoints will not be " - "generated; but, API endpoints and playground will work as expected." + f"⚠️ Using pydantic {PYDANTIC_VERSION}. " + f"OpenAPI docs for invoke, batch, stream, stream_log " + f"endpoints will not be generated. API endpoints and playground " + f"should work as expected. " + f"If you need to see the docs, you can downgrade to pydantic 1. " + "For example, `pip install pydantic==1.10.13`. " + f"See https://github.com/tiangolo/fastapi/issues/10360 for details." ) print() @@ -525,11 +526,38 @@ def _route_name_with_config(name: str) -> str: if hasattr(app, "openapi_tags") and (path or (app not in _APP_SEEN)): if not path: _APP_SEEN.add(app) + + if _PYDANTIC_MAJOR_VERSION == 1: + # Documentation for the default endpoints + default_endpoint_tags = { + "name": route_tags[0] if route_tags else "default", + } + elif _PYDANTIC_MAJOR_VERSION == 2: + # When using pydantic v2, we cannot generate openapi docs for + # the invoke/batch/stream/stream_log endpoints since the underlying + # models are from the pydantic.v1 namespace and cannot be supported + # by fastapi's. + # https://github.com/tiangolo/fastapi/issues/10360 + default_endpoint_tags = { + "name": route_tags[0] if route_tags else "default", + "description": ( + f"⚠️ Using pydantic {PYDANTIC_VERSION}. " + f"OpenAPI docs for `invoke`, `batch`, `stream`, `stream_log` " + f"endpoints will not be generated. API endpoints and playground " + f"should work as expected. " + f"If you need to see the docs, you can downgrade to pydantic 1. " + "For example, `pip install pydantic==1.10.13`" + f"See https://github.com/tiangolo/fastapi/issues/10360 for details." + ), + } + else: + raise AssertionError( + f"Expected pydantic major version 1 or 2, got {_PYDANTIC_MAJOR_VERSION}" + ) + app.openapi_tags = [ *(getattr(app, "openapi_tags", []) or []), - { - "name": route_tags[0] if route_tags else "default", - }, + default_endpoint_tags, { "name": route_tags_with_config[0], "description": ( From f9600fe470a6ee23b76593ffe93be70a44182a44 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 17 Nov 2023 10:34:30 -0500 Subject: [PATCH 2/2] x --- langserve/pydantic_v1.py | 7 ++++++- langserve/server.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/langserve/pydantic_v1.py b/langserve/pydantic_v1.py index 35bfd975..b31ca823 100644 --- a/langserve/pydantic_v1.py +++ b/langserve/pydantic_v1.py @@ -13,7 +13,12 @@ try: # F401: imported but unused - from pydantic.v1 import BaseModel, Field, ValidationError, create_model # noqa: F401 + from pydantic.v1 import ( # noqa: F401 + BaseModel, + Field, + ValidationError, + create_model, + ) except ImportError: from pydantic import BaseModel, Field, ValidationError, create_model # noqa: F401 diff --git a/langserve/server.py b/langserve/server.py index ad8689f4..7d9a4407 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -37,12 +37,17 @@ from langsmith.utils import tracing_is_enabled from typing_extensions import Annotated -from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model - from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict from langserve.lzstring import LZString from langserve.playground import serve_playground -from langserve.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, PYDANTIC_VERSION +from langserve.pydantic_v1 import ( + _PYDANTIC_MAJOR_VERSION, + PYDANTIC_VERSION, + BaseModel, + Field, + ValidationError, + create_model, +) from langserve.schema import ( BatchResponseMetadata, CustomUserType,