Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Sep 6, 2024
1 parent 86cf510 commit 8f9b8b2
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 5 deletions.
53 changes: 53 additions & 0 deletions langserve/_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Dict, Type, cast

from pydantic import BaseModel, ConfigDict, RootModel
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)


def _create_root_model(name: str, type_: Any) -> Type[RootModel]:
"""Create a base class."""

def schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_

def model_json_schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_

base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
# Should replace __module__ with caller based on stack frame.
"__module__": "langserve._pydantic",
}

custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[RootModel], custom_root_type)
1 change: 0 additions & 1 deletion langserve/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
ChatGeneration,
ChatGenerationChunk,
Generation,
LLMResult,
)
from langchain_core.prompt_values import ChatPromptValueConcrete
from langchain_core.prompts.base import StringPromptValue
Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
from langchain_core.documents.base import Document
from langchain_core.messages import HumanMessage, HumanMessageChunk, SystemMessage
from langchain_core.outputs import ChatGeneration
from pydantic import BaseModel

from langserve.serialization import (
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
SystemMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts import (
ChatPromptTemplate,
Expand Down Expand Up @@ -75,7 +74,7 @@
from langserve.server import add_routes
from tests.unit_tests.utils.llms import FakeListLLM, GenericFakeChatModel
from tests.unit_tests.utils.serde import recursive_dump
from tests.unit_tests.utils.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.utils.stubs import _AnyIdAIMessage
from tests.unit_tests.utils.tracer import FakeTracer


Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/utils/test_fake_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

from tests.unit_tests.utils.llms import GenericFakeChatModel
Expand Down

0 comments on commit 8f9b8b2

Please sign in to comment.