diff --git a/langserve/client.py b/langserve/client.py index 5e707c94..9cc92eee 100644 --- a/langserve/client.py +++ b/langserve/client.py @@ -49,10 +49,68 @@ logger = logging.getLogger(__name__) -def _without_callbacks(config: Optional[RunnableConfig]) -> RunnableConfig: - """Evict callbacks from the config since those are definitely not supported.""" +def _is_json_serializable(obj: Any) -> bool: + """Return True if the object is json serializable.""" + if isinstance(obj, (tuple, list, dict, str, int, float, bool, type(None))): + return True + else: + return False + + +def _keep_json_serializable(obj: Any) -> Any: + """Traverse the object recursively and removes non-json serializable objects.""" + if isinstance(obj, dict): + return { + k: _keep_json_serializable(v) + for k, v in obj.items() + if isinstance(k, str) and _is_json_serializable(v) + } + elif isinstance(obj, (list, tuple)): + return [_keep_json_serializable(v) for v in obj if _is_json_serializable(v)] + elif _is_json_serializable(obj): + return obj + else: + raise AssertionError("This code should not be reachable. If it's reached") + + +def _prepare_config_for_server( + config: Optional[RunnableConfig], *, ignore_unserializable: bool = True +) -> RunnableConfig: + """Evict information from the config that should not be sent to the server. + + This includes: + - callbacks: Callbacks are handled separately + - non-json serializable objects: We cannot serialize then the correct behavior + these appear frequently in the config of the runnable but are only needed + in the local scope of the config (they do not need to be sent to the server). + An example are the write / read channel objects populated by langgraph, + or the 'messages' field in configurable populated by RunnableWithMessageHistory. + + Args: + config: The config to clean up + ignore_unserializable: If True, will ignore non-json serializable objects + found in the 'configurable' field of the config. 
+ This is expected to be the safe default to use since the server + should not be specifying configurable objects that are not json + serializable. This logic is expected mostly to with non serializable + content that was created for local use by the runnable, and + is not needed by the server. + If False, will raise an error if a non-json serializable object is found. + + Returns: + A cleaned up version of the config that can be sent to the server. + """ _config = config or {} - return {k: v for k, v in _config.items() if k != "callbacks"} + without_callbacks = {k: v for k, v in _config.items() if k != "callbacks"} + if "configurable" in without_callbacks: + # Get a version of + + if ignore_unserializable: + without_callbacks["configurable"] = _keep_json_serializable( + without_callbacks["configurable"] + ) + + return without_callbacks @lru_cache(maxsize=1_000) # Will accommodate up to 1_000 different error messages @@ -277,7 +335,7 @@ def _invoke( "/invoke", json={ "input": self._lc_serializer.dumpd(input), - "config": _without_callbacks(config), + "config": _prepare_config_for_server(config), "kwargs": kwargs, }, ) @@ -308,7 +366,7 @@ async def _ainvoke( "/invoke", json={ "input": self._lc_serializer.dumpd(input), - "config": _without_callbacks(config), + "config": _prepare_config_for_server(config), "kwargs": kwargs, }, ) @@ -343,9 +401,9 @@ def _batch( ) if isinstance(config, list): - _config = [_without_callbacks(c) for c in config] + _config = [_prepare_config_for_server(c) for c in config] else: - _config = _without_callbacks(config) + _config = _prepare_config_for_server(config) response = self.sync_client.post( "/batch", @@ -396,9 +454,9 @@ async def _abatch( ) if isinstance(config, list): - _config = [_without_callbacks(c) for c in config] + _config = [_prepare_config_for_server(c) for c in config] else: - _config = _without_callbacks(config) + _config = _prepare_config_for_server(config) response = await self.async_client.post( "/batch", @@ -460,7 
+518,7 @@ def stream( ) data = { "input": self._lc_serializer.dumpd(input), - "config": _without_callbacks(config), + "config": _prepare_config_for_server(config), "kwargs": kwargs, } endpoint = urljoin(self.url, "stream") @@ -546,7 +604,7 @@ async def astream( ) data = { "input": self._lc_serializer.dumpd(input), - "config": _without_callbacks(config), + "config": _prepare_config_for_server(config), "kwargs": kwargs, } endpoint = urljoin(self.url, "stream") @@ -648,7 +706,7 @@ async def astream_log( ) data = { "input": self._lc_serializer.dumpd(input), - "config": _without_callbacks(config), + "config": _prepare_config_for_server(config), "kwargs": kwargs, "diff": True, "include_names": include_names, @@ -754,7 +812,7 @@ async def astream_events( ) data = { "input": self._lc_serializer.dumpd(input), - "config": _without_callbacks(config), + "config": _prepare_config_for_server(config), "kwargs": kwargs, "include_names": include_names, "include_types": include_types, diff --git a/tests/unit_tests/test_remote_client.py b/tests/unit_tests/test_remote_client.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 336d9299..bf4eb926 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -33,11 +33,13 @@ from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough from langchain.schema.runnable.base import RunnableLambda from langchain.schema.runnable.utils import ConfigurableField, Input, Output +from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.documents import Document -from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.outputs import ChatGenerationChunk, LLMResult -from langchain_core.prompts import 
async def test_remote_configurable_remote_runnable() -> None:
    """Test that a remote client runnable works inside a configurable wrapper.

    The client runnable is wrapped in a RunnableWithMessageHistory, and the
    test verifies that the extra information RunnableWithMessageHistory puts
    into the config does not interfere with the serialization logic.
    """

    class InMemoryHistory(BaseChatMessageHistory, BaseModel):
        """In memory implementation of chat message history."""

        messages: List[BaseMessage] = Field(default_factory=list)

        def add_message(self, message: BaseMessage) -> None:
            """Append the message to the history."""
            self.messages.append(message)

        def clear(self) -> None:
            """Replace the stored messages with an empty list."""
            self.messages = []

    # Histories are kept in a dict owned by the test so that the final state
    # can be inspected once the chain has run.
    store = {}

    def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
        """Return the history of the session, creating it on first use."""
        history = store.get(session_id)
        if history is None:
            history = store[session_id] = InMemoryHistory()
        return history

    fake_reply = AIMessage(content="Hello World!")
    model = GenericFakeChatModel(messages=cycle([fake_reply]))
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You're an assistant who's good at {ability}"),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    app = FastAPI()
    add_routes(app, prompt | model)

    # Invoke request
    async with get_async_remote_runnable(app, raise_app_exceptions=False) as client:
        # The history wrapper looks sessions up through get_by_session_id.
        chain_with_history = RunnableWithMessageHistory(
            client,
            get_by_session_id,
            input_messages_key="question",
            history_messages_key="history",
        )
        request_config = {"configurable": {"session_id": "1"}}
        result = await chain_with_history.ainvoke({"question": "hi"}, request_config)

        assert result == AIMessage(content="Hello World!")
        expected_messages = [
            HumanMessage(content="hi"),
            AIMessage(content="Hello World!"),
        ]
        assert store == {"1": InMemoryHistory(messages=expected_messages)}