From 45ddb62a430dcb7300e67291917c3887b720a37d Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 12 Sep 2024 17:05:08 -0400 Subject: [PATCH] qxqx --- langserve/client.py | 5 ++++- langserve/server.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/langserve/client.py b/langserve/client.py index 55968937..1dde56c6 100644 --- a/langserve/client.py +++ b/langserve/client.py @@ -285,6 +285,7 @@ def __init__( cert: Optional[CertTypes] = None, client_kwargs: Optional[Dict[str, Any]] = None, use_server_callback_events: bool = True, + serializer: Optional[Serializer] = None, ) -> None: """Initialize the client. @@ -300,6 +301,8 @@ def __init__( and async httpx clients use_server_callback_events: Whether to invoke callbacks on any callback events returned by the server. + serializer: The serializer to use for serializing and deserializing + data. If not provided, a default serializer will be used. """ _client_kwargs = client_kwargs or {} # Enforce trailing slash @@ -327,7 +330,7 @@ def __init__( # Register cleanup handler once RemoteRunnable is garbage collected weakref.finalize(self, _close_clients, self.sync_client, self.async_client) - self._lc_serializer = WellKnownLCSerializer() + self._lc_serializer = serializer or WellKnownLCSerializer() self._use_server_callback_events = use_server_callback_events def _invoke( diff --git a/langserve/server.py b/langserve/server.py index 182a7960..4602a1d3 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -26,6 +26,7 @@ TokenFeedbackConfig, _is_hosted, ) +from langserve.serialization import Serializer try: from fastapi import APIRouter, Depends, FastAPI, Request, Response @@ -263,6 +264,7 @@ def add_routes( dependencies: Optional[Sequence[Depends]] = None, playground_type: Literal["default", "chat"] = "default", astream_events_version: Literal["v1", "v2"] = "v2", + serializer: Optional[Serializer] = None, ) -> None: """Register the routes on the given FastAPI app or APIRouter. @@ -383,6 +385,8 @@ def add_routes( which message types are supported etc.) astream_events_version: version of the stream events endpoint to use. By default "v2". + serializer: The serializer to use for serializing the output. If not provided, + the default serializer will be used. """ # noqa: E501 if not isinstance(runnable, Runnable): raise TypeError( @@ -447,6 +451,7 @@ def add_routes( stream_log_name_allow_list=stream_log_name_allow_list, playground_type=playground_type, astream_events_version=astream_events_version, + serializer=serializer, ) namespace = path or ""