From b89eb80feef6f1f1ab14d89aa84bdb74366cbd6f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 12 Sep 2024 17:02:54 -0400 Subject: [PATCH 1/5] x --- langserve/api_handler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 79cd3277..e4d5d655 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -62,7 +62,7 @@ PublicTraceLink, PublicTraceLinkCreateRequest, ) -from langserve.serialization import WellKnownLCSerializer +from langserve.serialization import WellKnownLCSerializer, Serializer from langserve.validation import ( BatchBaseResponse, BatchRequestShallowValidator, @@ -536,6 +536,7 @@ def __init__( stream_log_name_allow_list: Optional[Sequence[str]] = None, playground_type: Literal["default", "chat"] = "default", astream_events_version: Literal["v1", "v2"] = "v2", + serializer: Optional[Serializer] = None, ) -> None: """Create an API handler for the given runnable. @@ -600,6 +601,8 @@ def __init__( TODO: Introduce deprecation for this parameter to rename it astream_events_version: version of the stream events endpoint to use. By default "v2". + serializer: optional serializer to use for serializing the output. + If not provided, the default serializer will be used. """ if importlib.util.find_spec("sse_starlette") is None: raise ImportError( @@ -638,7 +641,7 @@ def __init__( ) self._include_callback_events = include_callback_events self._per_req_config_modifier = per_req_config_modifier - self._serializer = WellKnownLCSerializer() + self._serializer = serializer or WellKnownLCSerializer() self._enable_feedback_endpoint = enable_feedback_endpoint self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint self._names_in_stream_allow_list = stream_log_name_allow_list From 45ddb62a430dcb7300e67291917c3887b720a37d Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 12 Sep 2024 17:05:08 -0400 Subject: [PATCH 2/5] qxqx --- langserve/client.py | 5 ++++- langserve/server.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/langserve/client.py b/langserve/client.py index 55968937..1dde56c6 100644 --- a/langserve/client.py +++ b/langserve/client.py @@ -285,6 +285,7 @@ def __init__( cert: Optional[CertTypes] = None, client_kwargs: Optional[Dict[str, Any]] = None, use_server_callback_events: bool = True, + serializer: Optional[Serializer] = None, ) -> None: """Initialize the client. @@ -300,6 +301,8 @@ def __init__( and async httpx clients use_server_callback_events: Whether to invoke callbacks on any callback events returned by the server. + serializer: The serializer to use for serializing and deserializing + data. If not provided, a default serializer will be used. """ _client_kwargs = client_kwargs or {} # Enforce trailing slash @@ -327,7 +330,7 @@ def __init__( # Register cleanup handler once RemoteRunnable is garbage collected weakref.finalize(self, _close_clients, self.sync_client, self.async_client) - self._lc_serializer = WellKnownLCSerializer() + self._lc_serializer = serializer or WellKnownLCSerializer() self._use_server_callback_events = use_server_callback_events def _invoke( diff --git a/langserve/server.py b/langserve/server.py index 182a7960..4602a1d3 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -26,6 +26,7 @@ TokenFeedbackConfig, _is_hosted, ) +from langserve.serialization import Serializer try: from fastapi import APIRouter, Depends, FastAPI, Request, Response @@ -263,6 +264,7 @@ def add_routes( dependencies: Optional[Sequence[Depends]] = None, playground_type: Literal["default", "chat"] = "default", astream_events_version: Literal["v1", "v2"] = "v2", + serializer: Optional[Serializer] = None, ) -> None: """Register the routes on the given FastAPI app or APIRouter. @@ -383,6 +385,8 @@ def add_routes( which message types are supported etc.) astream_events_version: version of the stream events endpoint to use. By default "v2". + serializer: The serializer to use for serializing the output. If not provided, + the default serializer will be used. """ # noqa: E501 if not isinstance(runnable, Runnable): raise TypeError( @@ -447,6 +451,7 @@ def add_routes( stream_log_name_allow_list=stream_log_name_allow_list, playground_type=playground_type, astream_events_version=astream_events_version, + serializer=serializer, ) namespace = path or "" From f04828122ba5b5334a3119045bb614d070c859e0 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 12 Sep 2024 17:19:20 -0400 Subject: [PATCH 3/5] x --- tests/unit_tests/test_server_client.py | 49 ++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index c87a69ee..e196bf68 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -56,6 +56,7 @@ from langsmith import schemas as ls_schemas from langsmith.client import Client from langsmith.schemas import FeedbackIngestToken +from orjson import orjson from pydantic import BaseModel, Field, __version__ from pytest import MonkeyPatch from pytest_mock import MockerFixture @@ -231,13 +232,17 @@ async def get_async_test_client( @asynccontextmanager async def get_async_remote_runnable( - server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True + server: FastAPI, + *, + path: Optional[str] = None, + raise_app_exceptions: bool = True, + **kwargs: Any, ) -> RemoteRunnable: """Get an async client.""" url = "http://localhost:9999" if path: url += path - remote_runnable_client = RemoteRunnable(url=url) + remote_runnable_client = RemoteRunnable(url=url, **kwargs) async with get_async_test_client( server, path=path, raise_app_exceptions=raise_app_exceptions @@ -2146,6 +2151,46 @@ async def check_types(inputs: VariousTypes) -> int: ) +async def test_custom_serialization() -> None: + """Test updating the config based on the raw request object.""" + from langserve.serialization import Serializer + + class CustomObject: + def __init__(self, x: int) -> None: + self.x = x + + class CustomSerializer(Serializer): + def dumpd(self, obj: Any) -> Any: + """Convert the given object to a JSON serializable object.""" + return orjson.loads(orjson.dumps(obj)) + + def dumps(self, obj: Any) -> bytes: + """Dump the given object as a JSON string.""" + return orjson.dumps(obj) + + def loadd(self, obj: Any) -> Any: + """Load the given object.""" + raise NotImplementedError() + + def loads(self, s: bytes) -> Any: + """Load the given JSON string.""" + return orjson.loads(s) + + def foo(x: int) -> Any: + """Add one to simulate a valid function.""" + return 2 + + app = FastAPI() + server_runnable = RunnableLambda(foo) + add_routes(app, server_runnable, serializer=CustomSerializer()) + + async with get_async_remote_runnable( + app, raise_app_exceptions=True, serializer=CustomSerializer() + ) as runnable: + result = await runnable.ainvoke(5) + assert result == {} + + async def test_endpoint_configurations() -> None: """Test enabling/disabling endpoints.""" app = FastAPI() From 3310b006fdcd0d00cab2f204d85203347bb218b2 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 14 Sep 2024 13:53:53 -0400 Subject: [PATCH 4/5] x --- langserve/api_handler.py | 2 +- langserve/serialization.py | 16 ++++------------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index e4d5d655..51cea2c1 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -62,7 +62,7 @@ PublicTraceLink, PublicTraceLinkCreateRequest, ) -from langserve.serialization import WellKnownLCSerializer, Serializer +from langserve.serialization import Serializer, WellKnownLCSerializer from langserve.validation import ( BatchBaseResponse, BatchRequestShallowValidator, diff --git a/langserve/serialization.py b/langserve/serialization.py index 5093d5de..a528be54 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -157,9 +157,9 @@ def _decode_event_data(value: Any) -> Any: class Serializer(abc.ABC): - @abc.abstractmethod def dumpd(self, obj: Any) -> Any: """Convert the given object to a JSON serializable object.""" + return orjson.loads(self.dumps(obj)) @abc.abstractmethod def dumps(self, obj: Any) -> bytes: @@ -169,16 +169,12 @@ def dumps(self, obj: Any) -> bytes: def loads(self, s: bytes) -> Any: """Load the given JSON string.""" - @abc.abstractmethod - def loadd(self, obj: Any) -> Any: - """Load the given object.""" + def loads(self, s: bytes) -> Any: + """Load the given JSON string.""" + return self.loadd(orjson.loads(s)) class WellKnownLCSerializer(Serializer): - def dumpd(self, obj: Any) -> Any: - """Convert the given object to a JSON serializable object.""" - return orjson.loads(orjson.dumps(obj, default=default)) - def dumps(self, obj: Any) -> bytes: """Dump the given object as a JSON string.""" return orjson.dumps(obj, default=default) @@ -187,10 +183,6 @@ def loadd(self, obj: Any) -> Any: """Load the given object.""" return _decode_lc_objects(obj) - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" - return self.loadd(orjson.loads(s)) - def _project_top_level(model: BaseModel) -> Dict[str, Any]: """Project the top level of the model as dict.""" From 71657c59a96dde16d41984db6d2da8f85e591e39 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 14 Sep 2024 14:04:37 -0400 Subject: [PATCH 5/5] x --- langserve/serialization.py | 34 ++++++++++++++++++++------ tests/unit_tests/test_server_client.py | 27 +++++++++++--------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/langserve/serialization.py b/langserve/serialization.py index a528be54..c208c0c1 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -161,26 +161,44 @@ def dumpd(self, obj: Any) -> Any: """Convert the given object to a JSON serializable object.""" return orjson.loads(self.dumps(obj)) + def loads(self, s: bytes) -> Any: + """Load the given JSON string.""" + return self.loadd(orjson.loads(s)) + @abc.abstractmethod def dumps(self, obj: Any) -> bytes: - """Dump the given object as a JSON string.""" + """Dump the given object to a JSON byte string.""" @abc.abstractmethod - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" + def loadd(self, s: bytes) -> Any: + """Given a python object, load it into a well known object. - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" - return self.loadd(orjson.loads(s)) + The obj represents content that was json loaded from a string, but + not yet validated or converted into a well known object. + """ class WellKnownLCSerializer(Serializer): + """A pre-defined serializer for well known LangChain objects. + + This is the default serialized used by LangServe for serializing and + de-serializing well known LangChain objects. + + If you need to extend the serialization capabilities for your own application, + feel free to create a new instance of the Serializer class and implement + the abstract methods dumps and loadd. + """ + def dumps(self, obj: Any) -> bytes: - """Dump the given object as a JSON string.""" + """Dump the given object to a JSON byte string.""" return orjson.dumps(obj, default=default) def loadd(self, obj: Any) -> Any: - """Load the given object.""" + """Given a python object, load it into a well known object. + + The obj represents content that was json loaded from a string, but + not yet validated or converted into a well known object. + """ return _decode_lc_objects(obj) diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 8ef470c6..fe328566 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -2293,26 +2293,28 @@ class CustomObject: def __init__(self, x: int) -> None: self.x = x - class CustomSerializer(Serializer): - def dumpd(self, obj: Any) -> Any: - """Convert the given object to a JSON serializable object.""" - return orjson.loads(orjson.dumps(obj)) + def __eq__(self, other) -> bool: + return self.x == other.x + class CustomSerializer(Serializer): def dumps(self, obj: Any) -> bytes: """Dump the given object as a JSON string.""" - return orjson.dumps(obj) + if isinstance(obj, CustomObject): + return orjson.dumps({"x": obj.x}) + else: + return orjson.dumps(obj) def loadd(self, obj: Any) -> Any: """Load the given object.""" - raise NotImplementedError() - - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" - return orjson.loads(s) + if isinstance(obj, bytes): + obj = obj.decode("utf-8") + if obj.get("x"): + return CustomObject(x=obj["x"]) + return obj def foo(x: int) -> Any: """Add one to simulate a valid function.""" - return 2 + return CustomObject(x=5) app = FastAPI() server_runnable = RunnableLambda(foo) @@ -2322,7 +2324,8 @@ def foo(x: int) -> Any: app, raise_app_exceptions=True, serializer=CustomSerializer() ) as runnable: result = await runnable.ainvoke(5) - assert result == {} + assert isinstance(result, CustomObject) + assert result == CustomObject(x=5) async def test_endpoint_configurations() -> None: