Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade to pydantic 2 #744

Merged
merged 21 commits into from
Sep 9, 2024
94 changes: 0 additions & 94 deletions .github/workflows/_pydantic_compatibility.yml

This file was deleted.

7 changes: 0 additions & 7 deletions .github/workflows/langserve_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ jobs:
with:
working-directory: .
secrets: inherit

pydantic-compatibility:
uses:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: .
secrets: inherit
test:
timeout-minutes: 10
runs-on: ubuntu-latest
Expand Down
53 changes: 53 additions & 0 deletions langserve/_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Dict, Type, cast

from pydantic import BaseModel, ConfigDict, RootModel
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)


def _create_root_model(name: str, type_: Any) -> Type[RootModel]:
"""Create a base class."""

def schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_

def model_json_schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_

base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
# Should replace __module__ with caller based on stack frame.
"__module__": "langserve._pydantic",
}

custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[RootModel], custom_root_type)
56 changes: 26 additions & 30 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@
from langsmith import client as ls_client
from langsmith.schemas import FeedbackIngestToken
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel, Field, RootModel, ValidationError, create_model
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from typing_extensions import TypedDict

from langserve._pydantic import _create_root_model
from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
from langserve.lzstring import LZString
from langserve.playground import serve_playground
from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model
from langserve.schema import (
BatchResponseMetadata,
CustomUserType,
Expand Down Expand Up @@ -256,10 +257,12 @@ def _update_config_with_defaults(
}
metadata.update(hosted_metadata)

non_overridable_default_config = RunnableConfig(
run_name=run_name,
metadata=metadata,
)
non_overridable_default_config: RunnableConfig = {
"metadata": metadata,
}

if run_name:
non_overridable_default_config["run_name"] = run_name

# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
Expand All @@ -280,8 +283,8 @@ def _update_config_with_defaults(

def _unpack_input(validated_model: BaseModel) -> Any:
"""Unpack the decoded input from the validated model."""
if hasattr(validated_model, "__root__"):
model = validated_model.__root__
if isinstance(validated_model, RootModel):
model = validated_model.root
else:
model = validated_model

Expand All @@ -305,7 +308,7 @@ def _rename_pydantic_model(model: Type[BaseModel], prefix: str) -> Type[BaseMode
"""Rename the given pydantic model to the given name."""
return create_model(
prefix + model.__name__,
__config__=model.__config__,
__config__=model.model_config,
**{
fieldname: (
_rename_pydantic_model(field.annotation, prefix)
Expand All @@ -314,10 +317,10 @@ def _rename_pydantic_model(model: Type[BaseModel], prefix: str) -> Type[BaseMode
Field(
field.default,
title=fieldname,
description=field.field_info.description,
description=field.description,
),
)
for fieldname, field in model.__fields__.items()
for fieldname, field in model.model_fields.items()
},
)

Expand All @@ -334,7 +337,7 @@ def _resolve_model(
if isclass(type_) and issubclass(type_, BaseModel):
model = type_
else:
model = create_model(default_name, __root__=(type_, ...))
model = _create_root_model(default_name, type_)

hash_ = model.schema_json()

Expand Down Expand Up @@ -367,11 +370,7 @@ def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[Base
A new model with name prepended with the given namespace.
"""
model_with_unique_name = _rename_pydantic_model(model, namespace)
if "run_id" in model_with_unique_name.__annotations__:
# Help resolve reference by providing namespace references
model_with_unique_name.update_forward_refs(uuid=uuid)
else:
model_with_unique_name.update_forward_refs()
model_with_unique_name.model_rebuild()
return model_with_unique_name


Expand Down Expand Up @@ -404,7 +403,7 @@ def _with_validation_error_translation() -> Generator[None, None, None]:
try:
yield
except ValidationError as e:
raise RequestValidationError(e.errors(), body=e.model)
raise RequestValidationError(e.errors())


def _json_encode_response(model: BaseModel) -> JSONResponse:
Expand All @@ -424,39 +423,36 @@ def _json_encode_response(model: BaseModel) -> JSONResponse:

if isinstance(model, InvokeBaseResponse):
# Invoke Response
# Collapse '__root__' from output field if it exists. This is done
# Collapse 'root' from output field if it exists. This is done
# automatically by fastapi when annotating request and response with
# We need to do this manually since we're using vanilla JSONResponse
if isinstance(obj["output"], dict) and "__root__" in obj["output"]:
obj["output"] = obj["output"]["__root__"]
if isinstance(obj["output"], dict) and "root" in obj["output"]:
obj["output"] = obj["output"]["root"]

if "callback_events" in obj:
for idx, callback_event in enumerate(obj["callback_events"]):
if isinstance(callback_event, dict) and "__root__" in callback_event:
obj["callback_events"][idx] = callback_event["__root__"]
if isinstance(callback_event, dict) and "root" in callback_event:
obj["callback_events"][idx] = callback_event["root"]
elif isinstance(model, BatchBaseResponse):
if not isinstance(obj["output"], list):
raise AssertionError("Expected output to be a list")

# Collapse '__root__' from output field if it exists. This is done
# Collapse 'root' from output field if it exists. This is done
# automatically by fastapi when annotating request and response with
# We need to do this manually since we're using vanilla JSONResponse
outputs = obj["output"]
for idx, output in enumerate(outputs):
if isinstance(output, dict) and "__root__" in output:
outputs[idx] = output["__root__"]
if isinstance(output, dict) and "root" in output:
outputs[idx] = output["root"]

if "callback_events" in obj:
if not isinstance(obj["callback_events"], list):
raise AssertionError("Expected callback_events to be a list")

for callback_events in obj["callback_events"]:
for idx, callback_event in enumerate(callback_events):
if (
isinstance(callback_event, dict)
and "__root__" in callback_event
):
callback_events[idx] = callback_event["__root__"]
if isinstance(callback_event, dict) and "root" in callback_event:
callback_events[idx] = callback_event["root"]
else:
raise AssertionError(
f"Expected a InvokeBaseResponse or BatchBaseResponse got: {type(model)}"
Expand Down
3 changes: 1 addition & 2 deletions langserve/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from fastapi.responses import Response
from langchain_core.runnables import Runnable

from langserve.pydantic_v1 import BaseModel
from pydantic import BaseModel


class PlaygroundTemplate(Template):
Expand Down
33 changes: 0 additions & 33 deletions langserve/pydantic_v1.py

This file was deleted.

9 changes: 5 additions & 4 deletions langserve/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Dict, List, Optional, Union
from uuid import UUID

from pydantic import BaseModel # Floats between v1 and v2

from langserve.pydantic_v1 import BaseModel as BaseModelV1
from langserve.pydantic_v1 import Field
from pydantic import (
BaseModel,
Field,
)
from pydantic import BaseModel as BaseModelV1


class CustomUserType(BaseModelV1):
Expand Down
Loading
Loading