From b60480ed2c331d2c7dd7450aa39a642482ca6c37 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 14 Sep 2024 13:29:17 -0400 Subject: [PATCH] x --- langserve/callbacks.py | 18 +++++- langserve/validation.py | 80 ++++++++++---------------- tests/unit_tests/test_server_client.py | 80 +++++++------------------- tests/unit_tests/utils/tracer.py | 24 ++++---- 4 files changed, 79 insertions(+), 123 deletions(-) diff --git a/langserve/callbacks.py b/langserve/callbacks.py index fd129e99..18dc9420 100644 --- a/langserve/callbacks.py +++ b/langserve/callbacks.py @@ -445,7 +445,14 @@ async def ahandle_callbacks( if event["parent_run_id"] is None: # How do we make sure it's None!? event["parent_run_id"] = callback_manager.run_id - event_data = {key: value for key, value in event.items() if key != "type"} + event_data = { + key: value + for key, value in event.items() + if key != "type" and key != "kwargs" + } + + if "kwargs" in event: + event_data.update(event["kwargs"]) await ahandle_event( # Unpacking like this may not work @@ -467,7 +474,14 @@ def handle_callbacks( if event["parent_run_id"] is None: # How do we make sure it's None!? event["parent_run_id"] = callback_manager.run_id - event_data = {key: value for key, value in event.items() if key != "type"} + event_data = { + key: value + for key, value in event.items() + if key != "type" and key != "kwargs" + } + + if "kwargs" in event: + event_data.update(event["kwargs"]) handle_event( # Unpacking like this may not work diff --git a/langserve/validation.py b/langserve/validation.py index 9f4e254e..1256cdfb 100644 --- a/langserve/validation.py +++ b/langserve/validation.py @@ -396,27 +396,29 @@ class StreamEventsParameters(BaseModel): # status code and a message. -class OnChainStart(BaseModel): - """On Chain Start Callback Event.""" +class BaseCallback(BaseModel): + """Base class for all callback events.""" - serialized: Optional[Dict[str, Any]] = None - inputs: Any run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + + +class OnChainStart(BaseCallback): + """On Chain Start Callback Event.""" + + serialized: Optional[Dict[str, Any]] = None + inputs: Any + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chain_start"] = "on_chain_start" -class OnChainEnd(BaseModel): +class OnChainEnd(BaseCallback): """On Chain End Callback Event.""" outputs: Any - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chain_end"] = "on_chain_end" @@ -428,18 +430,15 @@ class Error(BaseModel): type: Literal["error"] = "error" -class OnChainError(BaseModel): +class OnChainError(BaseCallback): """On Chain Error Callback Event.""" error: Error - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chain_error"] = "on_chain_error" -class OnToolStart(BaseModel): +class OnToolStart(BaseCallback): """On Tool Start Callback Event.""" serialized: Optional[Dict[str, Any]] = None @@ -448,18 +447,18 @@ class OnToolStart(BaseModel): parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_tool_start"] = "on_tool_start" -class OnToolEnd(BaseModel): +class OnToolEnd(BaseCallback): """On Tool End Callback Event.""" output: str run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_tool_end"] = "on_tool_end" @@ -467,27 +466,20 @@ class OnToolError(BaseModel): """On Tool Error Callback Event.""" error: Error - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_tool_error"] = "on_tool_error" -class OnChatModelStart(BaseModel): +class OnChatModelStart(BaseCallback): """On Chat Model Start Callback Event.""" serialized: Optional[Dict[str, Any]] = None messages: List[List[BaseMessage]] - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chat_model_start"] = "on_chat_model_start" -class OnLLMStart(BaseModel): +class OnLLMStart(BaseCallback): """On LLM Start Callback Event.""" serialized: Optional[Dict[str, Any]] = None @@ -496,7 +488,7 @@ class OnLLMStart(BaseModel): parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_llm_start"] = "on_llm_start" @@ -515,49 +507,39 @@ class LLMResult(BaseModel): """List of metadata info for model call for each input.""" -class OnLLMEnd(BaseModel): +class OnLLMEnd(BaseCallback): """On LLM End Callback Event.""" response: LLMResult - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_llm_end"] = "on_llm_end" -class OnRetrieverStart(BaseModel): +class OnRetrieverStart(BaseCallback): """On Retriever Start Callback Event.""" serialized: Optional[Dict[str, Any]] = None query: str - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_retriever_start"] = "on_retriever_start" -class OnRetrieverError(BaseModel): +class OnRetrieverError(BaseCallback): """On Retriever Error Callback Event.""" error: Error run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_retriever_error"] = "on_retriever_error" -class OnRetrieverEnd(BaseModel): +class OnRetrieverEnd(BaseCallback): """On Retriever End Callback Event.""" documents: Sequence[Document] - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_retriever_end"] = "on_retriever_end" diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 6a7ffed0..cfce20af 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -536,26 +536,22 @@ def test_invoke(sync_remote_runnable: RemoteRunnable) -> None: assert remote_runnable_run.child_runs[0].name == "add_one_or_passthrough" -def test_foo_foo_bar_bar(sync_remote_runnable: RemoteRunnable) -> None: +def test_batch_tracer_with_single_input(sync_remote_runnable: RemoteRunnable) -> None: + """Test passing a single tracer to batch.""" tracer = FakeTracer() assert sync_remote_runnable.batch([1], config={"callbacks": [tracer]}) == [2] assert len(tracer.runs) == 1 - # Child run exists from server side (this fails) - assert len(tracer.runs[0].child_runs[0]) == 1 # - assert tracer.runs[0].child_runs[0].name == "RunnableLambda" + assert len(tracer.runs[0].child_runs) == 1 + assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough" def test_batch(sync_remote_runnable: RemoteRunnable) -> None: """Test sync batch.""" - # assert sync_remote_runnable.batch([]) == [] - # assert sync_remote_runnable.batch([1, 2, 3]) == [2, 3, 4] - # assert sync_remote_runnable.batch([HumanMessage(content="hello")]) == [ - # HumanMessage(content="hello") - # ] - - tracer = FakeTracer() - assert sync_remote_runnable.batch([1, 1], config={"callbacks": [tracer]}) == [2, 3] - assert len(tracer.runs) == 1 + assert sync_remote_runnable.batch([]) == [] + assert sync_remote_runnable.batch([1, 2, 3]) == [2, 3, 4] + assert sync_remote_runnable.batch([HumanMessage(content="hello")]) == [ + HumanMessage(content="hello") + ] # Test callbacks # Using a single tracer for both inputs @@ -563,18 +559,8 @@ def test_batch(sync_remote_runnable: RemoteRunnable) -> None: assert sync_remote_runnable.batch([1, 2], config={"callbacks": [tracer]}) == [2, 3] assert len(tracer.runs) == 2 - # Light test to verify that we're picking up information about the server side - # function being invoked via a callback. - # assert tracer.runs[0] == {} - assert tracer.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) - - assert tracer.runs[1].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[1].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) + assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer.runs[1].child_runs[0].name == "add_one_or_passthrough" # Verify that each tracer gets its own run tracer1 = FakeTracer() @@ -586,17 +572,8 @@ def test_batch(sync_remote_runnable: RemoteRunnable) -> None: assert len(tracer2.runs) == 1 # Light test to verify that we're picking up information about the server side # function being invoked via a callback. - assert tracer1.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer1.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) - - assert tracer2.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer2.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) + assert tracer1.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer2.runs[0].child_runs[0].name == "add_one_or_passthrough" async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None: @@ -617,7 +594,7 @@ async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None: # due to asyncio supporting contextvars starting from 3.11. # check the python version now if sys.version_info >= (3, 11): - assert len(tracer.runs) == 1, "Failed for python >= 3.11" + assert len(tracer.runs) == 2, "Failed for python >= 3.11" first_run = tracer.runs[0] remote_runnable_run = ( @@ -652,17 +629,9 @@ async def test_abatch(async_remote_runnable: RemoteRunnable) -> None: [1, 2], config={"callbacks": [tracer]} ) == [2, 3] assert len(tracer.runs) == 2 - # Light test to verify that we're picking up information about the server side - # function being invoked via a callback. - assert tracer.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) - assert tracer.runs[1].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[1].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) + assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer.runs[1].child_runs[0].name == "add_one_or_passthrough" # Verify that each tracer gets its own run tracer1 = FakeTracer() @@ -672,19 +641,9 @@ async def test_abatch(async_remote_runnable: RemoteRunnable) -> None: ) == [2, 3] assert len(tracer1.runs) == 1 assert len(tracer2.runs) == 1 - # Light test to verify that we're picking up information about the server side - # function being invoked via a callback. - assert tracer1.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer1.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) - assert tracer2.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer2.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) + assert tracer1.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer2.runs[0].child_runs[0].name == "add_one_or_passthrough" async def test_astream(async_remote_runnable: RemoteRunnable) -> None: @@ -1202,6 +1161,7 @@ async def add_one(x: int) -> int: { "kwargs": {}, "outputs": 2, + "metadata": None, "parent_run_id": None, "tags": [], "run_id": None, @@ -1264,6 +1224,7 @@ async def add_one(x: int) -> int: "kwargs": {}, "outputs": 2, "parent_run_id": None, + "metadata": None, "run_id": None, "tags": [], "type": "on_chain_end", @@ -1288,6 +1249,7 @@ async def add_one(x: int) -> int: "kwargs": {}, "outputs": 3, "parent_run_id": None, + "metadata": None, "run_id": None, "tags": [], "type": "on_chain_end", diff --git a/tests/unit_tests/utils/tracer.py b/tests/unit_tests/utils/tracer.py index 4bbf1103..bdb39f44 100644 --- a/tests/unit_tests/utils/tracer.py +++ b/tests/unit_tests/utils/tracer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Optional +from typing import Any, Dict, List, Optional from uuid import UUID from langchain_core.tracers import BaseTracer @@ -40,16 +40,16 @@ def _copy_run(self, run: Run) -> Run: ) def _create_chain_run( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, - **kwargs: Any, + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, ) -> Run: if name is None: # can't raise an exception from here, but can get a breakpoint @@ -65,10 +65,8 @@ def _create_chain_run( run_type, name, **kwargs, - ) - def _persist_run(self, run: Run) -> None: """Persist a run.""" self.runs.append(self._copy_run(run))