diff --git a/libs/standard-tests/langchain_tests/integration_tests/tools.py b/libs/standard-tests/langchain_tests/integration_tests/tools.py index 2fcd610ccc052..86a26338ed984 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/tools.py +++ b/libs/standard-tests/langchain_tests/integration_tests/tools.py @@ -29,15 +29,10 @@ def test_invoke_matches_output_schema(self, tool: BaseTool) -> None: ) result = tool.invoke(tool_call) - if tool.response_format == "content": - tool_message = result - elif tool.response_format == "content_and_artifact": - # should be (content, artifact) - assert isinstance(result, tuple) - assert len(result) == 2 - tool_message, artifact = result - - assert artifact # artifact can be anything, but shouldn't be none + tool_message = result + if tool.response_format == "content_and_artifact": + # artifact can be anything, except none + assert tool_message.artifact is not None # check content is a valid ToolMessage content assert isinstance(tool_message.content, (str, list)) @@ -59,15 +54,10 @@ async def test_async_invoke_matches_output_schema(self, tool: BaseTool) -> None: ) result = await tool.ainvoke(tool_call) - if tool.response_format == "content": - tool_message = result - elif tool.response_format == "content_and_artifact": - # should be (content, artifact) - assert isinstance(result, tuple) - assert len(result) == 2 - tool_message, artifact = result - - assert artifact # artifact can be anything, but shouldn't be none + tool_message = result + if tool.response_format == "content_and_artifact": + # artifact can be anything, except none + assert tool_message.artifact is not None # check content is a valid ToolMessage content assert isinstance(tool_message.content, (str, list)) diff --git a/libs/standard-tests/tests/unit_tests/test_basic_tool.py b/libs/standard-tests/tests/unit_tests/test_basic_tool.py index 046e15d760662..d8aa16504e4c9 100644 --- a/libs/standard-tests/tests/unit_tests/test_basic_tool.py +++ b/libs/standard-tests/tests/unit_tests/test_basic_tool.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Literal, Type from langchain_core.tools import BaseTool @@ -16,6 +16,17 @@ def _run(self, a: int, b: int) -> int: return a * b + 80 +class ParrotMultiplyArtifactTool(BaseTool): # type: ignore + name: str = "ParrotMultiplyArtifactTool" + description: str = ( + "Multiply two numbers like a parrot. Parrots always add eighty for their matey." + ) + response_format: Literal["content_and_artifact"] = "content_and_artifact" + + def _run(self, a: int, b: int) -> tuple[int, str]: + return a * b + 80, "parrot artifact" + + class TestParrotMultiplyToolUnit(ToolsUnitTests): @property def tool_constructor(self) -> Type[ParrotMultiplyTool]: @@ -60,3 +71,26 @@ def tool_invoke_params_example(self) -> dict: have {"name", "id", "args"} keys. """ return {"a": 2, "b": 3} + + +class TestParrotMultiplyArtifactToolIntegration(ToolsIntegrationTests): + @property + def tool_constructor(self) -> Type[ParrotMultiplyArtifactTool]: + return ParrotMultiplyArtifactTool + + @property + def tool_constructor_params(self) -> dict: + # if your tool constructor instead required initialization arguments like + # `def __init__(self, some_arg: int):`, you would return those here + # as a dictionary, e.g.: `return {'some_arg': 42}` + return {} + + @property + def tool_invoke_params_example(self) -> dict: + """ + Returns a dictionary representing the "args" of an example tool call. + + This should NOT be a ToolCall dict - i.e. it should not + have {"name", "id", "args"} keys. + """ + return {"a": 2, "b": 3}