From 404931aa3dd3e589bd1e02406295a2f7631ab9a7 Mon Sep 17 00:00:00 2001 From: Isaac Wasserman Date: Wed, 22 Jan 2025 18:35:13 -0500 Subject: [PATCH] unpack tool output --- src/langchain_mcp/toolkit.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/langchain_mcp/toolkit.py b/src/langchain_mcp/toolkit.py index dc0da0a..b49b642 100644 --- a/src/langchain_mcp/toolkit.py +++ b/src/langchain_mcp/toolkit.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import asyncio +import json import warnings from collections.abc import Callable @@ -104,9 +105,10 @@ class MCPTool(BaseTool): session: ClientSession handle_tool_error: bool | str | Callable[[ToolException], str] | None = True + response_format: str = "content_and_artifact" @t.override - def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple[str, t.Any]: warnings.warn( "Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy standard tests.", stacklevel=1, @@ -114,12 +116,16 @@ def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any: return asyncio.run(self._arun(*args, **kwargs)) @t.override - async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Tuple[str, t.Any]: result = await self.session.call_tool(self.name, arguments=kwargs) content = pydantic_core.to_json(result.content).decode() if result.isError: raise ToolException(content) - return content + content_blocks = json.loads(content) + text_content = "\n".join([block["text"] for block in content_blocks if block["type"] == "text"]) + artifact = [block["artifact"] for block in content_blocks if "artifact" in block] + artifact = None if not artifact else (artifact[0] if len(artifact) == 1 else artifact) + return text_content, artifact @t.override @property