diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index acfa7c41..32817d01 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -536,13 +536,23 @@ def test_invoke(sync_remote_runnable: RemoteRunnable) -> None: assert remote_runnable_run.child_runs[0].name == "add_one_or_passthrough" +def test_foo_foo_bar_bar(sync_remote_runnable: RemoteRunnable) -> None: + tracer = FakeTracer() + assert sync_remote_runnable.batch([1], config={"callbacks": [tracer]}) == [2] + assert len(tracer.runs) == 1 + + def test_batch(sync_remote_runnable: RemoteRunnable) -> None: """Test sync batch.""" - assert sync_remote_runnable.batch([]) == [] - assert sync_remote_runnable.batch([1, 2, 3]) == [2, 3, 4] - assert sync_remote_runnable.batch([HumanMessage(content="hello")]) == [ - HumanMessage(content="hello") - ] + # assert sync_remote_runnable.batch([]) == [] + # assert sync_remote_runnable.batch([1, 2, 3]) == [2, 3, 4] + # assert sync_remote_runnable.batch([HumanMessage(content="hello")]) == [ + # HumanMessage(content="hello") + # ] + + tracer = FakeTracer() + assert sync_remote_runnable.batch([1, 1], config={"callbacks": [tracer]}) == [2, 3] + assert len(tracer.runs) == 1 # Test callbacks # Using a single tracer for both inputs @@ -552,7 +562,7 @@ def test_batch(sync_remote_runnable: RemoteRunnable) -> None: # Light test to verify that we're picking up information about the server side # function being invoked via a callback. - assert tracer.runs[0] == {} + # assert tracer.runs[0] == {} assert tracer.runs[0].child_runs[0].name == "RunnableLambda" assert ( tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" diff --git a/tests/unit_tests/utils/tracer.py b/tests/unit_tests/utils/tracer.py index a634b9b9..c1a41f78 100644 --- a/tests/unit_tests/utils/tracer.py +++ b/tests/unit_tests/utils/tracer.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Any, Optional from uuid import UUID from langchain_core.tracers import BaseTracer @@ -39,6 +39,35 @@ def _copy_run(self, run: Run) -> Run: } ) + def _create_chain_run( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + if name is None: + # can't raise an exception from here, but can get a breakpoint + import pdb; pdb.set_trace() + return super()._create_chain_run( + serialized, + inputs, + run_id, + tags, + parent_run_id, + metadata, + run_type, + name, + **kwargs, + + ) + + def _persist_run(self, run: Run) -> None: """Persist a run.""" self.runs.append(self._copy_run(run))