diff --git a/langserve/serialization.py b/langserve/serialization.py index a528be54..c208c0c1 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -161,26 +161,44 @@ def dumpd(self, obj: Any) -> Any: """Convert the given object to a JSON serializable object.""" return orjson.loads(self.dumps(obj)) + def loads(self, s: bytes) -> Any: + """Load the given JSON string.""" + return self.loadd(orjson.loads(s)) + @abc.abstractmethod def dumps(self, obj: Any) -> bytes: - """Dump the given object as a JSON string.""" + """Dump the given object to a JSON byte string.""" @abc.abstractmethod - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" + def loadd(self, s: bytes) -> Any: + """Given a python object, load it into a well known object. - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" - return self.loadd(orjson.loads(s)) + The obj represents content that was json loaded from a string, but + not yet validated or converted into a well known object. + """ class WellKnownLCSerializer(Serializer): + """A pre-defined serializer for well known LangChain objects. + + This is the default serialized used by LangServe for serializing and + de-serializing well known LangChain objects. + + If you need to extend the serialization capabilities for your own application, + feel free to create a new instance of the Serializer class and implement + the abstract methods dumps and loadd. + """ + def dumps(self, obj: Any) -> bytes: - """Dump the given object as a JSON string.""" + """Dump the given object to a JSON byte string.""" return orjson.dumps(obj, default=default) def loadd(self, obj: Any) -> Any: - """Load the given object.""" + """Given a python object, load it into a well known object. + + The obj represents content that was json loaded from a string, but + not yet validated or converted into a well known object. + """ return _decode_lc_objects(obj) diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 8ef470c6..fe328566 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -2293,26 +2293,28 @@ class CustomObject: def __init__(self, x: int) -> None: self.x = x - class CustomSerializer(Serializer): - def dumpd(self, obj: Any) -> Any: - """Convert the given object to a JSON serializable object.""" - return orjson.loads(orjson.dumps(obj)) + def __eq__(self, other) -> bool: + return self.x == other.x + class CustomSerializer(Serializer): def dumps(self, obj: Any) -> bytes: """Dump the given object as a JSON string.""" - return orjson.dumps(obj) + if isinstance(obj, CustomObject): + return orjson.dumps({"x": obj.x}) + else: + return orjson.dumps(obj) def loadd(self, obj: Any) -> Any: """Load the given object.""" - raise NotImplementedError() - - def loads(self, s: bytes) -> Any: - """Load the given JSON string.""" - return orjson.loads(s) + if isinstance(obj, bytes): + obj = obj.decode("utf-8") + if obj.get("x"): + return CustomObject(x=obj["x"]) + return obj def foo(x: int) -> Any: """Add one to simulate a valid function.""" - return 2 + return CustomObject(x=5) app = FastAPI() server_runnable = RunnableLambda(foo) @@ -2322,7 +2324,8 @@ def foo(x: int) -> Any: app, raise_app_exceptions=True, serializer=CustomSerializer() ) as runnable: result = await runnable.ainvoke(5) - assert result == {} + assert isinstance(result, CustomObject) + assert result == CustomObject(x=5) async def test_endpoint_configurations() -> None: