Fix: Drop non json serializable values in the config prior to sending it to the server (#425)

This PR drops all non-JSON-serializable values in the config prior to sending the config to the server.

This seems like the correct behavior in general, which is why it's being merged without exposing a way to control it.

Any non-serializable value seems to be used only locally by the runnable or by something that wraps it. Server-side configurable runnables are supposed to expose only configuration that is trivially JSON serializable, because that's how configurable runnables were designed.
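
As a rough illustration of the intended behavior, here is a minimal self-contained sketch of the cleanup this PR performs (names simplified for the sketch; the real helpers are _is_json_serializable, _keep_json_serializable, and _prepare_config_for_server in the diff below):

from typing import Any

# Python types that map directly onto JSON values.
JSON_TYPES = (tuple, list, dict, str, int, float, bool, type(None))

def keep_json_serializable(obj: Any) -> Any:
    """Recursively keep only entries whose Python type maps onto JSON."""
    if isinstance(obj, dict):
        return {
            k: keep_json_serializable(v)
            for k, v in obj.items()
            if isinstance(k, str) and isinstance(v, JSON_TYPES)
        }
    if isinstance(obj, (list, tuple)):
        return [keep_json_serializable(v) for v in obj if isinstance(v, JSON_TYPES)]
    return obj  # callers only pass values already known to be JSON types

config = {
    "callbacks": [object()],          # dropped: callbacks are handled separately
    "configurable": {
        "session_id": "1",            # kept: plain string
        "message_history": object(),  # dropped: not a JSON type (local-only state)
    },
}
cleaned = {k: v for k, v in config.items() if k != "callbacks"}
cleaned["configurable"] = keep_json_serializable(cleaned["configurable"])
assert cleaned == {"configurable": {"session_id": "1"}}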
eyurtsev authored Jan 27, 2024
1 parent 534009a commit 1828f91
Showing 3 changed files with 140 additions and 15 deletions.
84 changes: 71 additions & 13 deletions langserve/client.py
@@ -49,10 +49,68 @@
logger = logging.getLogger(__name__)


-def _without_callbacks(config: Optional[RunnableConfig]) -> RunnableConfig:
-    """Evict callbacks from the config since those are definitely not supported."""
+def _is_json_serializable(obj: Any) -> bool:
+    """Return True if the object is json serializable."""
+    if isinstance(obj, (tuple, list, dict, str, int, float, bool, type(None))):
+        return True
+    else:
+        return False
+
+
+def _keep_json_serializable(obj: Any) -> Any:
+    """Traverse the object recursively and remove non-json serializable objects."""
+    if isinstance(obj, dict):
+        return {
+            k: _keep_json_serializable(v)
+            for k, v in obj.items()
+            if isinstance(k, str) and _is_json_serializable(v)
+        }
+    elif isinstance(obj, (list, tuple)):
+        return [_keep_json_serializable(v) for v in obj if _is_json_serializable(v)]
+    elif _is_json_serializable(obj):
+        return obj
+    else:
+        raise AssertionError(
+            "This code should not be reachable. If it's reached, it's a bug."
+        )
+
+
+def _prepare_config_for_server(
+    config: Optional[RunnableConfig], *, ignore_unserializable: bool = True
+) -> RunnableConfig:
+    """Evict information from the config that should not be sent to the server.
+
+    This includes:
+
+    - callbacks: Callbacks are handled separately.
+    - non-json serializable objects: We cannot serialize them, and dropping
+      them is the correct behavior; these appear frequently in the config of
+      the runnable but are only needed in the local scope of the config (they
+      do not need to be sent to the server). Examples are the write / read
+      channel objects populated by langgraph, or the 'messages' field in
+      configurable populated by RunnableWithMessageHistory.
+
+    Args:
+        config: The config to clean up.
+        ignore_unserializable: If True, will ignore non-json serializable
+            objects found in the 'configurable' field of the config.
+            This is expected to be the safe default to use since the server
+            should not be specifying configurable objects that are not json
+            serializable. This logic is expected mostly to deal with
+            non-serializable content that was created for local use by the
+            runnable and is not needed by the server.
+            If False, will raise an error if a non-json serializable object
+            is found.
+
+    Returns:
+        A cleaned up version of the config that can be sent to the server.
+    """
     _config = config or {}
-    return {k: v for k, v in _config.items() if k != "callbacks"}
+    without_callbacks = {k: v for k, v in _config.items() if k != "callbacks"}
+
+    if "configurable" in without_callbacks:
+        # Get a version of the 'configurable' field containing only
+        # json serializable values.
+        if ignore_unserializable:
+            without_callbacks["configurable"] = _keep_json_serializable(
+                without_callbacks["configurable"]
+            )
+
+    return without_callbacks


@lru_cache(maxsize=1_000) # Will accommodate up to 1_000 different error messages
@@ -277,7 +335,7 @@ def _invoke(
"/invoke",
json={
"input": self._lc_serializer.dumpd(input),
"config": _without_callbacks(config),
"config": _prepare_config_for_server(config),
"kwargs": kwargs,
},
)
@@ -308,7 +366,7 @@ async def _ainvoke(
"/invoke",
json={
"input": self._lc_serializer.dumpd(input),
"config": _without_callbacks(config),
"config": _prepare_config_for_server(config),
"kwargs": kwargs,
},
)
@@ -343,9 +401,9 @@ def _batch(
)

if isinstance(config, list):
-            _config = [_without_callbacks(c) for c in config]
+            _config = [_prepare_config_for_server(c) for c in config]
else:
-            _config = _without_callbacks(config)
+            _config = _prepare_config_for_server(config)

response = self.sync_client.post(
"/batch",
@@ -396,9 +454,9 @@ async def _abatch(
)

if isinstance(config, list):
-            _config = [_without_callbacks(c) for c in config]
+            _config = [_prepare_config_for_server(c) for c in config]
else:
-            _config = _without_callbacks(config)
+            _config = _prepare_config_for_server(config)

response = await self.async_client.post(
"/batch",
@@ -460,7 +518,7 @@ def stream(
)
data = {
"input": self._lc_serializer.dumpd(input),
"config": _without_callbacks(config),
"config": _prepare_config_for_server(config),
"kwargs": kwargs,
}
endpoint = urljoin(self.url, "stream")
@@ -546,7 +604,7 @@ async def astream(
)
data = {
"input": self._lc_serializer.dumpd(input),
"config": _without_callbacks(config),
"config": _prepare_config_for_server(config),
"kwargs": kwargs,
}
endpoint = urljoin(self.url, "stream")
@@ -648,7 +706,7 @@ async def astream_log(
)
data = {
"input": self._lc_serializer.dumpd(input),
"config": _without_callbacks(config),
"config": _prepare_config_for_server(config),
"kwargs": kwargs,
"diff": True,
"include_names": include_names,
@@ -754,7 +812,7 @@ async def astream_events(
)
data = {
"input": self._lc_serializer.dumpd(input),
"config": _without_callbacks(config),
"config": _prepare_config_for_server(config),
"kwargs": kwargs,
"include_names": include_names,
"include_types": include_types,
Empty file.
71 changes: 69 additions & 2 deletions tests/unit_tests/test_server_client.py
@@ -33,11 +33,13 @@
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.schema.runnable.base import RunnableLambda
from langchain.schema.runnable.utils import ConfigurableField, Input, Output
+from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
-from langchain_core.messages import AIMessage, AIMessageChunk
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGenerationChunk, LLMResult
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
from langsmith import schemas as ls_schemas
from pytest import MonkeyPatch
from pytest_mock import MockerFixture
@@ -2719,3 +2721,68 @@ async def verify_token(x_token: Annotated[str, Header()]) -> None:
f"with {response.text}. "
f"Should not return 422 status code since we are passing the header."
)


async def test_remote_configurable_remote_runnable() -> None:
    """Test that a client runnable that's configurable works.

    Here, we wrap the client runnable in a RunnableWithMessageHistory.
    The test verifies that the extra information populated by
    RunnableWithMessageHistory does not interfere with the serialization logic.
    """
app = FastAPI()

class InMemoryHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history."""

messages: List[BaseMessage] = Field(default_factory=list)

def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)

def clear(self) -> None:
self.messages = []

# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store = {}

def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = InMemoryHistory()
return store[session_id]

prompt = ChatPromptTemplate.from_messages(
[
("system", "You're an assistant who's good at {ability}"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
]
)

model = GenericFakeChatModel(messages=cycle([AIMessage(content="Hello World!")]))
chain = prompt | model

add_routes(app, chain)

# Invoke request
async with get_async_remote_runnable(app, raise_app_exceptions=False) as client:
chain_with_history = RunnableWithMessageHistory(
client,
# Uses the get_by_session_id function defined in the example
# above.
get_by_session_id,
input_messages_key="question",
history_messages_key="history",
)
result = await chain_with_history.ainvoke(
{"question": "hi"}, {"configurable": {"session_id": "1"}}
)
assert result == AIMessage(content="Hello World!")
assert store == {
"1": InMemoryHistory(
messages=[HumanMessage(content="hi"), AIMessage(content="Hello World!")]
)
}
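
Why this test exercises the new cleanup path: when RunnableWithMessageHistory invokes the wrapped client runnable, the forwarded config carries the live chat-history object alongside the JSON-friendly session id. A hedged sketch of that shape (the exact configurable key is an assumption for illustration; the commit itself only guarantees that non-JSON-serializable entries are dropped):

# Illustrative only, not part of the commit. A stand-in for the config the
# wrapped client receives: the live history object is not a JSON type, so
# _prepare_config_for_server drops it and only "session_id" reaches the server.
forwarded_config = {
    "configurable": {
        "session_id": "1",            # JSON-serializable: sent to the server
        "message_history": object(),  # assumed key; local-only object: dropped
    }
}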
