From 1e24edce0807ee223f8c7841aefabaebf7059d7b Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 14 Sep 2024 14:06:08 -0400 Subject: [PATCH] Add ability to specify custom serializer (#764) Allow users to define a custom serializer --- langserve/api_handler.py | 7 +++- langserve/client.py | 5 ++- langserve/serialization.py | 42 +++++++++++++-------- langserve/server.py | 5 +++ tests/unit_tests/test_server_client.py | 52 +++++++++++++++++++++++++- 5 files changed, 90 insertions(+), 21 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 79cd3277..51cea2c1 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -62,7 +62,7 @@ PublicTraceLink, PublicTraceLinkCreateRequest, ) -from langserve.serialization import WellKnownLCSerializer +from langserve.serialization import Serializer, WellKnownLCSerializer from langserve.validation import ( BatchBaseResponse, BatchRequestShallowValidator, @@ -536,6 +536,7 @@ def __init__( stream_log_name_allow_list: Optional[Sequence[str]] = None, playground_type: Literal["default", "chat"] = "default", astream_events_version: Literal["v1", "v2"] = "v2", + serializer: Optional[Serializer] = None, ) -> None: """Create an API handler for the given runnable. @@ -600,6 +601,8 @@ def __init__( TODO: Introduce deprecation for this parameter to rename it astream_events_version: version of the stream events endpoint to use. By default "v2". + serializer: optional serializer to use for serializing the output. + If not provided, the default serializer will be used. """ if importlib.util.find_spec("sse_starlette") is None: raise ImportError( @@ -638,7 +641,7 @@ def __init__( ) self._include_callback_events = include_callback_events self._per_req_config_modifier = per_req_config_modifier - self._serializer = WellKnownLCSerializer() + self._serializer = serializer or WellKnownLCSerializer() self._enable_feedback_endpoint = enable_feedback_endpoint self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint self._names_in_stream_allow_list = stream_log_name_allow_list diff --git a/langserve/client.py b/langserve/client.py index 55968937..1dde56c6 100644 --- a/langserve/client.py +++ b/langserve/client.py @@ -285,6 +285,7 @@ def __init__( cert: Optional[CertTypes] = None, client_kwargs: Optional[Dict[str, Any]] = None, use_server_callback_events: bool = True, + serializer: Optional[Serializer] = None, ) -> None: """Initialize the client. @@ -300,6 +301,8 @@ def __init__( and async httpx clients use_server_callback_events: Whether to invoke callbacks on any callback events returned by the server. + serializer: The serializer to use for serializing and deserializing + data. If not provided, a default serializer will be used. """ _client_kwargs = client_kwargs or {} # Enforce trailing slash @@ -327,7 +330,7 @@ def __init__( # Register cleanup handler once RemoteRunnable is garbage collected weakref.finalize(self, _close_clients, self.sync_client, self.async_client) - self._lc_serializer = WellKnownLCSerializer() + self._lc_serializer = serializer or WellKnownLCSerializer() self._use_server_callback_events = use_server_callback_events def _invoke( diff --git a/langserve/serialization.py b/langserve/serialization.py index 5093d5de..c208c0c1 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -157,39 +157,49 @@ def _decode_event_data(value: Any) -> Any: class Serializer(abc.ABC): - @abc.abstractmethod def dumpd(self, obj: Any) -> Any: """Convert the given object to a JSON serializable object.""" + return orjson.loads(self.dumps(obj)) + + def loads(self, s: bytes) -> Any: + """Load the given JSON string.""" + return self.loadd(orjson.loads(s)) @abc.abstractmethod def dumps(self, obj: Any) -> bytes: - """Dump the given object as a JSON string.""" + """Dump the given object to a JSON byte string.""" @abc.abstractmethod - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" + def loadd(self, s: bytes) -> Any: + """Given a python object, load it into a well known object. - @abc.abstractmethod - def loadd(self, obj: Any) -> Any: - """Load the given object.""" + The obj represents content that was json loaded from a string, but + not yet validated or converted into a well known object. + """ class WellKnownLCSerializer(Serializer): - def dumpd(self, obj: Any) -> Any: - """Convert the given object to a JSON serializable object.""" - return orjson.loads(orjson.dumps(obj, default=default)) + """A pre-defined serializer for well known LangChain objects. + + This is the default serialized used by LangServe for serializing and + de-serializing well known LangChain objects. + + If you need to extend the serialization capabilities for your own application, + feel free to create a new instance of the Serializer class and implement + the abstract methods dumps and loadd. + """ def dumps(self, obj: Any) -> bytes: - """Dump the given object as a JSON string.""" + """Dump the given object to a JSON byte string.""" return orjson.dumps(obj, default=default) def loadd(self, obj: Any) -> Any: - """Load the given object.""" - return _decode_lc_objects(obj) + """Given a python object, load it into a well known object. - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" - return self.loadd(orjson.loads(s)) + The obj represents content that was json loaded from a string, but + not yet validated or converted into a well known object. + """ + return _decode_lc_objects(obj) def _project_top_level(model: BaseModel) -> Dict[str, Any]: diff --git a/langserve/server.py b/langserve/server.py index 182a7960..4602a1d3 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -26,6 +26,7 @@ TokenFeedbackConfig, _is_hosted, ) +from langserve.serialization import Serializer try: from fastapi import APIRouter, Depends, FastAPI, Request, Response @@ -263,6 +264,7 @@ def add_routes( dependencies: Optional[Sequence[Depends]] = None, playground_type: Literal["default", "chat"] = "default", astream_events_version: Literal["v1", "v2"] = "v2", + serializer: Optional[Serializer] = None, ) -> None: """Register the routes on the given FastAPI app or APIRouter. @@ -383,6 +385,8 @@ def add_routes( which message types are supported etc.) astream_events_version: version of the stream events endpoint to use. By default "v2". + serializer: The serializer to use for serializing the output. If not provided, + the default serializer will be used. """ # noqa: E501 if not isinstance(runnable, Runnable): raise TypeError( @@ -447,6 +451,7 @@ def add_routes( stream_log_name_allow_list=stream_log_name_allow_list, playground_type=playground_type, astream_events_version=astream_events_version, + serializer=serializer, ) namespace = path or "" diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 2c8902b5..fe328566 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -56,6 +56,7 @@ from langsmith import schemas as ls_schemas from langsmith.client import Client from langsmith.schemas import FeedbackIngestToken +from orjson import orjson from pydantic import BaseModel, Field, __version__ from pytest import MonkeyPatch from pytest_mock import MockerFixture @@ -244,13 +245,17 @@ async def get_async_test_client( @asynccontextmanager async def get_async_remote_runnable( - server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True + server: FastAPI, + *, + path: Optional[str] = None, + raise_app_exceptions: bool = True, + **kwargs: Any, ) -> RemoteRunnable: """Get an async client.""" url = "http://localhost:9999" if path: url += path - remote_runnable_client = RemoteRunnable(url=url) + remote_runnable_client = RemoteRunnable(url=url, **kwargs) async with get_async_test_client( server, path=path, raise_app_exceptions=raise_app_exceptions @@ -2280,6 +2285,49 @@ async def check_types(inputs: VariousTypes) -> int: ) +async def test_custom_serialization() -> None: + """Test updating the config based on the raw request object.""" + from langserve.serialization import Serializer + + class CustomObject: + def __init__(self, x: int) -> None: + self.x = x + + def __eq__(self, other) -> bool: + return self.x == other.x + + class CustomSerializer(Serializer): + def dumps(self, obj: Any) -> bytes: + """Dump the given object as a JSON string.""" + if isinstance(obj, CustomObject): + return orjson.dumps({"x": obj.x}) + else: + return orjson.dumps(obj) + + def loadd(self, obj: Any) -> Any: + """Load the given object.""" + if isinstance(obj, bytes): + obj = obj.decode("utf-8") + if obj.get("x"): + return CustomObject(x=obj["x"]) + return obj + + def foo(x: int) -> Any: + """Add one to simulate a valid function.""" + return CustomObject(x=5) + + app = FastAPI() + server_runnable = RunnableLambda(foo) + add_routes(app, server_runnable, serializer=CustomSerializer()) + + async with get_async_remote_runnable( + app, raise_app_exceptions=True, serializer=CustomSerializer() + ) as runnable: + result = await runnable.ainvoke(5) + assert isinstance(result, CustomObject) + assert result == CustomObject(x=5) + + async def test_endpoint_configurations() -> None: """Test enabling/disabling endpoints.""" app = FastAPI()