Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Sep 14, 2024
1 parent 51f785e commit b60480e
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 123 deletions.
18 changes: 16 additions & 2 deletions langserve/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,14 @@ async def ahandle_callbacks(
if event["parent_run_id"] is None: # How do we make sure it's None!?
event["parent_run_id"] = callback_manager.run_id

event_data = {key: value for key, value in event.items() if key != "type"}
event_data = {
key: value
for key, value in event.items()
if key != "type" and key != "kwargs"
}

if "kwargs" in event:
event_data.update(event["kwargs"])

await ahandle_event(
# Unpacking like this may not work
Expand All @@ -467,7 +474,14 @@ def handle_callbacks(
if event["parent_run_id"] is None: # How do we make sure it's None!?
event["parent_run_id"] = callback_manager.run_id

event_data = {key: value for key, value in event.items() if key != "type"}
event_data = {
key: value
for key, value in event.items()
if key != "type" and key != "kwargs"
}

if "kwargs" in event:
event_data.update(event["kwargs"])

handle_event(
# Unpacking like this may not work
Expand Down
80 changes: 31 additions & 49 deletions langserve/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,27 +396,29 @@ class StreamEventsParameters(BaseModel):
# status code and a message.


class OnChainStart(BaseModel):
"""On Chain Start Callback Event."""
class BaseCallback(BaseModel):
"""Base class for all callback events."""

serialized: Optional[Dict[str, Any]] = None
inputs: Any
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None


class OnChainStart(BaseCallback):
"""On Chain Start Callback Event."""

serialized: Optional[Dict[str, Any]] = None
inputs: Any
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_start"] = "on_chain_start"


class OnChainEnd(BaseModel):
class OnChainEnd(BaseCallback):
"""On Chain End Callback Event."""

outputs: Any
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_end"] = "on_chain_end"


Expand All @@ -428,18 +430,15 @@ class Error(BaseModel):
type: Literal["error"] = "error"


class OnChainError(BaseModel):
class OnChainError(BaseCallback):
"""On Chain Error Callback Event."""

error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_error"] = "on_chain_error"


class OnToolStart(BaseModel):
class OnToolStart(BaseCallback):
"""On Tool Start Callback Event."""

serialized: Optional[Dict[str, Any]] = None
Expand All @@ -448,46 +447,39 @@ class OnToolStart(BaseModel):
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_start"] = "on_tool_start"


class OnToolEnd(BaseModel):
class OnToolEnd(BaseCallback):
"""On Tool End Callback Event."""

output: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_end"] = "on_tool_end"


class OnToolError(BaseModel):
"""On Tool Error Callback Event."""

error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_error"] = "on_tool_error"


class OnChatModelStart(BaseModel):
class OnChatModelStart(BaseCallback):
"""On Chat Model Start Callback Event."""

serialized: Optional[Dict[str, Any]] = None
messages: List[List[BaseMessage]]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chat_model_start"] = "on_chat_model_start"


class OnLLMStart(BaseModel):
class OnLLMStart(BaseCallback):
"""On LLM Start Callback Event."""

serialized: Optional[Dict[str, Any]] = None
Expand All @@ -496,7 +488,7 @@ class OnLLMStart(BaseModel):
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_llm_start"] = "on_llm_start"


Expand All @@ -515,49 +507,39 @@ class LLMResult(BaseModel):
"""List of metadata info for model call for each input."""


class OnLLMEnd(BaseModel):
class OnLLMEnd(BaseCallback):
"""On LLM End Callback Event."""

response: LLMResult
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_llm_end"] = "on_llm_end"


class OnRetrieverStart(BaseModel):
class OnRetrieverStart(BaseCallback):
"""On Retriever Start Callback Event."""

serialized: Optional[Dict[str, Any]] = None
query: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_start"] = "on_retriever_start"


class OnRetrieverError(BaseModel):
class OnRetrieverError(BaseCallback):
"""On Retriever Error Callback Event."""

error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_error"] = "on_retriever_error"


class OnRetrieverEnd(BaseModel):
class OnRetrieverEnd(BaseCallback):
"""On Retriever End Callback Event."""

documents: Sequence[Document]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_end"] = "on_retriever_end"


Expand Down
80 changes: 21 additions & 59 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,45 +536,31 @@ def test_invoke(sync_remote_runnable: RemoteRunnable) -> None:
assert remote_runnable_run.child_runs[0].name == "add_one_or_passthrough"


def test_foo_foo_bar_bar(sync_remote_runnable: RemoteRunnable) -> None:
def test_batch_tracer_with_single_input(sync_remote_runnable: RemoteRunnable) -> None:
"""Test passing a single tracer to batch."""
tracer = FakeTracer()
assert sync_remote_runnable.batch([1], config={"callbacks": [tracer]}) == [2]
assert len(tracer.runs) == 1
# Child run exists from server side (this fails)
assert len(tracer.runs[0].child_runs[0]) == 1 #
assert tracer.runs[0].child_runs[0].name == "RunnableLambda"
assert len(tracer.runs[0].child_runs) == 1
assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough"


def test_batch(sync_remote_runnable: RemoteRunnable) -> None:
"""Test sync batch."""
# assert sync_remote_runnable.batch([]) == []
# assert sync_remote_runnable.batch([1, 2, 3]) == [2, 3, 4]
# assert sync_remote_runnable.batch([HumanMessage(content="hello")]) == [
# HumanMessage(content="hello")
# ]

tracer = FakeTracer()
assert sync_remote_runnable.batch([1, 1], config={"callbacks": [tracer]}) == [2, 3]
assert len(tracer.runs) == 1
assert sync_remote_runnable.batch([]) == []
assert sync_remote_runnable.batch([1, 2, 3]) == [2, 3, 4]
assert sync_remote_runnable.batch([HumanMessage(content="hello")]) == [
HumanMessage(content="hello")
]

# Test callbacks
# Using a single tracer for both inputs
tracer = FakeTracer()
assert sync_remote_runnable.batch([1, 2], config={"callbacks": [tracer]}) == [2, 3]
assert len(tracer.runs) == 2

# Light test to verify that we're picking up information about the server side
# function being invoked via a callback.
# assert tracer.runs[0] == {}
assert tracer.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough"
)

assert tracer.runs[1].child_runs[0].name == "RunnableLambda"
assert (
tracer.runs[1].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough"
)
assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough"
assert tracer.runs[1].child_runs[0].name == "add_one_or_passthrough"

# Verify that each tracer gets its own run
tracer1 = FakeTracer()
Expand All @@ -586,17 +572,8 @@ def test_batch(sync_remote_runnable: RemoteRunnable) -> None:
assert len(tracer2.runs) == 1
# Light test to verify that we're picking up information about the server side
# function being invoked via a callback.
assert tracer1.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer1.runs[0].child_runs[0].extra["kwargs"]["name"]
== "add_one_or_passthrough"
)

assert tracer2.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer2.runs[0].child_runs[0].extra["kwargs"]["name"]
== "add_one_or_passthrough"
)
assert tracer1.runs[0].child_runs[0].name == "add_one_or_passthrough"
assert tracer2.runs[0].child_runs[0].name == "add_one_or_passthrough"


async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None:
Expand All @@ -617,7 +594,7 @@ async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None:
# due to asyncio supporting contextvars starting from 3.11.
# check the python version now
if sys.version_info >= (3, 11):
assert len(tracer.runs) == 1, "Failed for python >= 3.11"
assert len(tracer.runs) == 2, "Failed for python >= 3.11"
first_run = tracer.runs[0]

remote_runnable_run = (
Expand Down Expand Up @@ -652,17 +629,9 @@ async def test_abatch(async_remote_runnable: RemoteRunnable) -> None:
[1, 2], config={"callbacks": [tracer]}
) == [2, 3]
assert len(tracer.runs) == 2
# Light test to verify that we're picking up information about the server side
# function being invoked via a callback.
assert tracer.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough"
)

assert tracer.runs[1].child_runs[0].name == "RunnableLambda"
assert (
tracer.runs[1].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough"
)
assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough"
assert tracer.runs[1].child_runs[0].name == "add_one_or_passthrough"

# Verify that each tracer gets its own run
tracer1 = FakeTracer()
Expand All @@ -672,19 +641,9 @@ async def test_abatch(async_remote_runnable: RemoteRunnable) -> None:
) == [2, 3]
assert len(tracer1.runs) == 1
assert len(tracer2.runs) == 1
# Light test to verify that we're picking up information about the server side
# function being invoked via a callback.
assert tracer1.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer1.runs[0].child_runs[0].extra["kwargs"]["name"]
== "add_one_or_passthrough"
)

assert tracer2.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer2.runs[0].child_runs[0].extra["kwargs"]["name"]
== "add_one_or_passthrough"
)
assert tracer1.runs[0].child_runs[0].name == "add_one_or_passthrough"
assert tracer2.runs[0].child_runs[0].name == "add_one_or_passthrough"


async def test_astream(async_remote_runnable: RemoteRunnable) -> None:
Expand Down Expand Up @@ -1202,6 +1161,7 @@ async def add_one(x: int) -> int:
{
"kwargs": {},
"outputs": 2,
"metadata": None,
"parent_run_id": None,
"tags": [],
"run_id": None,
Expand Down Expand Up @@ -1264,6 +1224,7 @@ async def add_one(x: int) -> int:
"kwargs": {},
"outputs": 2,
"parent_run_id": None,
"metadata": None,
"run_id": None,
"tags": [],
"type": "on_chain_end",
Expand All @@ -1288,6 +1249,7 @@ async def add_one(x: int) -> int:
"kwargs": {},
"outputs": 3,
"parent_run_id": None,
"metadata": None,
"run_id": None,
"tags": [],
"type": "on_chain_end",
Expand Down
Loading

0 comments on commit b60480e

Please sign in to comment.