Skip to content

Commit

Permalink
adds sample sse-based mcp-server and client support in Codespace Assi…
Browse files Browse the repository at this point in the history
…stant (#312)
  • Loading branch information
bkrabach authored Jan 30, 2025
1 parent 99ca554 commit 2561ea3
Show file tree
Hide file tree
Showing 22 changed files with 825 additions and 63 deletions.
2 changes: 2 additions & 0 deletions assistants/codespace-assistant/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@
"cSpell.words": [
"Codespaces",
"contentsafety",
"debugpy",
"deepmerge",
"devcontainer",
"dotenv",
"endregion",
"Excalidraw",
"fastapi",
"GIPHY",
"jsonschema",
"Langchain",
"modelcontextprotocol",
Expand Down
13 changes: 13 additions & 0 deletions assistants/codespace-assistant/assistant/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class AssistantConfigModel(BaseModel):
- Create core files and folders for the project as needed, such as README.md, .gitignore, etc.
- Create language specific files and folders as needed, such as package.json, pyproject.toml, etc.
- Files should include a newline at the end of the file.
- Provide instruction for the user on installing dependencies via cli instead of writing these
directly to the project files, this will ensure the user has the most up-to-date versions.
- Offer to keep the README and other documentation up-to-date with the latest project information, if
Expand All @@ -86,6 +87,18 @@ class AssistantConfigModel(BaseModel):
- Use 'pnpm' for managing dependencies (do not use 'npm' or 'yarn')
- It is ok to update '.vscode' folder contents and 'package.json' scripts as needed for adding run
and debug configurations, but do not add or remove any other files or folders.
- Consider the following strategy to improve approachability for both
developers and any AI assistants:
- Modularity and Conciseness: Each code file should not exceed one page in length, ensuring concise
and focused code. When a file exceeds one page, consider breaking it into smaller, more focused
files. Individual functions should be easily readable, wrapping larger blocks of code in functions
with clear names and purposes.
- Semantic Names: Use meaningful names for functions and modules to enhance understanding and
maintainability. These names will also be used for semantic searches by the AI assistant.
- Organized Structure: Maintain a well-organized structure, breaking down functionality into clear
and manageable components.
- Update Documentation: Keep documentation, including code comments, up-to-date with the latest
project information.
Ultimately, however, the user is in control of the project and can override the above guidance as needed.
""").strip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def get_mcp_server_configs(tools_config: ToolsConfigModel) -> List[MCPServerConf
b) Store facts about them as observations
"""),
),
MCPServerConfig(
name="GIPHY MCP Server",
command="http://127.0.0.1:6000/sse",
args=[],
),
# MCPServerConfig(
# name="Sequential Thinking MCP Server",
# command="npx",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,77 @@
from typing import AsyncIterator, List, Optional

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client

from .__mcp_server_configs import get_mcp_server_configs
from .__model import MCPServerConfig, ToolsConfigModel
from .__model import MCPServerConfig, MCPSession, ToolsConfigModel

logger = logging.getLogger(__name__)


@asynccontextmanager
async def connect_to_mcp_server(server_config: MCPServerConfig) -> AsyncIterator[Optional[ClientSession]]:
"""Connect to a single MCP server defined in the config."""
if server_config.command.startswith("http"):
async with connect_to_mcp_server_sse(server_config) as client_session:
yield client_session
else:
async with connect_to_mcp_server_stdio(server_config) as client_session:
yield client_session


@asynccontextmanager
async def connect_to_mcp_server_stdio(server_config: MCPServerConfig) -> AsyncIterator[Optional[ClientSession]]:
"""Connect to a single MCP server defined in the config."""

server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
try:
logger.debug(
f"Attempting to connect to {server_config.name} with command: {server_config.command} {' '.join(server_config.args)}"
)
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session # Yield the session for use
async with ClientSession(read_stream, write_stream) as client_session:
await client_session.initialize()
yield client_session # Yield the session for use
except Exception as e:
logger.exception(f"Error connecting to {server_config.name}: {e}")
yield None # Yield None if connection fails


async def establish_mcp_sessions(tools_config: ToolsConfigModel, stack: AsyncExitStack) -> List[ClientSession]:
@asynccontextmanager
async def connect_to_mcp_server_sse(server_config: MCPServerConfig) -> AsyncIterator[Optional[ClientSession]]:
"""Connect to a single MCP server defined in the config using SSE transport."""

try:
logger.debug(f"Attempting to connect to {server_config.name} with SSE transport: {server_config.command}")
async with sse_client(url=server_config.command, headers=server_config.env) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as client_session:
await client_session.initialize()
yield client_session # Yield the session for use
except Exception as e:
logger.exception(f"Error connecting to {server_config.name}: {e}")
yield None


async def establish_mcp_sessions(tools_config: ToolsConfigModel, stack: AsyncExitStack) -> List[MCPSession]:
"""
Establish connections to MCP servers using the provided AsyncExitStack.
"""

sessions: List[ClientSession] = []
mcp_sessions: List[MCPSession] = []
for server_config in get_mcp_server_configs(tools_config):
session: ClientSession | None = await stack.enter_async_context(connect_to_mcp_server(server_config))
if session:
sessions.append(session)
client_session: ClientSession | None = await stack.enter_async_context(connect_to_mcp_server(server_config))
if client_session:
# Create an MCP session with the client session
mcp_session = MCPSession(name=server_config.name, client_session=client_session)
# Initialize the session to load tools, resources, etc.
await mcp_session.initialize()
# Add the session to the list of established sessions
mcp_sessions.append(mcp_session)
else:
logger.warning(f"Could not establish session with {server_config.name}")
return sessions
return mcp_sessions


def get_mcp_server_prompts(tools_config: ToolsConfigModel) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,51 @@
from typing import List

import deepmerge
from mcp import ClientSession, Tool
from mcp.types import TextContent
from mcp import Tool
from mcp.types import EmbeddedResource, ImageContent, TextContent

from .__model import ToolCall, ToolCallResult, ToolMessageType
from .__model import MCPSession, ToolCall, ToolCallResult, ToolMessageType

logger = logging.getLogger(__name__)


async def retrieve_tools_from_sessions(sessions: List[ClientSession]) -> List[Tool]:
def retrieve_tools_from_sessions(mcp_sessions: List[MCPSession]) -> List[Tool]:
"""
Retrieve tools from all MCP sessions.
"""
all_tools: List[Tool] = []
for session in sessions:
try:
tools_response = await session.list_tools()
tools = tools_response.tools
all_tools.extend(tools)
logger.debug(f"Retrieved tools from session: {[tool.name for tool in tools]}")
except Exception as e:
logger.exception(f"Error retrieving tools from session: {e}")
return all_tools
return [tool for mcp_session in mcp_sessions for tool in mcp_session.tools]


async def handle_tool_call(
sessions: List[ClientSession],
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
all_mcp_tools: List[Tool],
method_metadata_key: str,
) -> ToolCallResult:
"""
Handle the tool call by invoking the appropriate tool and returning a ToolCallResult.
"""

# Initialize metadata
metadata = {}

tool = next((t for t in all_mcp_tools if t.name == tool_call.name), None)
if not tool:
# Find the tool and session from the full collection of sessions
mcp_session, tool = next(
(
(mcp_session, tool)
for mcp_session in mcp_sessions
for tool in mcp_session.tools
if tool.name == tool_call.name
),
(None, None),
)
if not mcp_session or not tool:
return ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found.",
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)

target_session = next(
(session for session in sessions if tool_call.name in [tool.name for tool in all_mcp_tools]), None
)

if not target_session:
raise ValueError(f"Tool '{tool_call.name}' not found in any of the sessions.")

# Update metadata with tool call details
deepmerge.always_merger.merge(
metadata,
Expand All @@ -69,15 +62,17 @@ async def handle_tool_call(

# Initialize tool_result
tool_result = None
tool_output: list[TextContent | ImageContent | EmbeddedResource] = []
content_items: List[str] = []

# Invoke the tool
try:
logger.debug(f"Invoking tool '{tool_call.name}' with arguments: {tool_call.arguments}")
tool_result = await target_session.call_tool(tool_call.name, tool_call.arguments)
tool_output = tool_result.content[0] if tool_result.content else ""
logger.debug(f"Invoking '{mcp_session.name}.{tool_call.name}' with arguments: {tool_call.arguments}")
tool_result = await mcp_session.client_session.call_tool(tool_call.name, tool_call.arguments)
tool_output = tool_result.content
except Exception as e:
logger.exception(f"Error executing tool '{tool_call.name}': {e}")
tool_output = f"An error occurred while executing the tool '{tool_call.to_json()}': {e}"
content_items.append(f"An error occurred while executing the tool '{tool_call.to_json()}': {e}")

# Update metadata with tool result
deepmerge.always_merger.merge(
Expand All @@ -91,18 +86,18 @@ async def handle_tool_call(
},
)

# Return the tool call result
content: str | None = None
if isinstance(tool_output, str):
content = tool_output

if isinstance(tool_output, TextContent):
content = tool_output.text
for tool_output_item in tool_output:
if isinstance(tool_output_item, TextContent):
content_items.append(tool_output_item.text)
if isinstance(tool_output_item, ImageContent):
content_items.append(tool_output_item.model_dump_json())
if isinstance(tool_output_item, EmbeddedResource):
content_items.append(tool_output_item.model_dump_json())

# Return the tool call result
return ToolCallResult(
id=tool_call.id,
content=content or "Error executing tool, unsupported output type.",
content="\n\n".join(content_items),
message_type=ToolMessageType.tool_result,
metadata=metadata,
)
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import json
import logging
from enum import StrEnum
from textwrap import dedent
from typing import Annotated, Any, List, Optional

from attr import dataclass
from mcp import ClientSession, Tool
from pydantic import BaseModel, Field
from semantic_workbench_assistant.config import UISchema

logger = logging.getLogger(__name__)


@dataclass
class MCPServerConfig:
Expand All @@ -17,6 +21,22 @@ class MCPServerConfig:
prompt: Optional[str] = None


class MCPSession:
name: str
client_session: ClientSession
tools: List[Tool] = []

def __init__(self, name: str, client_session: ClientSession) -> None:
self.name = name
self.client_session = client_session

async def initialize(self) -> None:
# Load all tools from the session, later we can do the same for resources, prompts, etc.
tools_result = await self.client_session.list_tools()
self.tools = tools_result.tools
logger.debug(f"Loaded {len(tools_result.tools)} tools from session '{self.name}'")


@dataclass
class ToolCall:
id: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import deepmerge
import openai_client
from mcp import ClientSession, Tool
from openai.types.chat import (
ChatCompletion,
ChatCompletionToolMessageParam,
Expand All @@ -18,6 +17,8 @@
)
from semantic_workbench_assistant.assistant_app import ConversationContext

from assistant.extensions.tools.__model import MCPSession

from ..config import AssistantConfigModel
from ..extensions.tools import (
ToolCall,
Expand All @@ -36,8 +37,7 @@
async def handle_completion(
step_result: StepResult,
completion: ParsedChatCompletion | ChatCompletion,
mcp_sessions: List[ClientSession],
mcp_tools: List[Tool],
mcp_sessions: List[MCPSession],
context: ConversationContext,
config: AssistantConfigModel,
silence_token: str,
Expand Down Expand Up @@ -166,7 +166,6 @@ async def handle_error(error_message: str) -> StepResult:
tool_call_result = await handle_tool_call(
mcp_sessions,
tool_call,
mcp_tools,
f"{metadata_key}:request:tool_call_{tool_call_count}",
)
except Exception as e:
Expand Down
10 changes: 3 additions & 7 deletions assistants/codespace-assistant/assistant/response/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from typing import Any, List

from assistant_extensions.attachments import AttachmentsExtension
from mcp import ClientSession
from semantic_workbench_api_model.workbench_model import (
MessageType,
NewConversationMessage,
)
from semantic_workbench_assistant.assistant_app import ConversationContext

from assistant.extensions.tools.__model import MCPSession

from ..config import AssistantConfigModel
from ..extensions.tools import (
establish_mcp_sessions,
get_mcp_server_prompts,
retrieve_tools_from_sessions,
)
from .step_handler import next_step

Expand All @@ -34,7 +34,7 @@ async def respond_to_conversation(

async with AsyncExitStack() as stack:
# If tools are enabled, establish connections to the MCP servers
mcp_sessions: List[ClientSession] = []
mcp_sessions: List[MCPSession] = []
if config.extensions_config.tools.enabled:
mcp_sessions = await establish_mcp_sessions(config.extensions_config.tools, stack)
if not mcp_sessions:
Expand All @@ -50,9 +50,6 @@ async def respond_to_conversation(
# Retrieve prompts from the MCP servers
mcp_prompts = get_mcp_server_prompts(config.extensions_config.tools)

# Retrieve tools from the MCP sessions
mcp_tools = await retrieve_tools_from_sessions(mcp_sessions)

# Initialize a loop control variable
max_steps = config.extensions_config.tools.max_steps
interrupted = False
Expand All @@ -77,7 +74,6 @@ async def respond_to_conversation(

step_result = await next_step(
mcp_sessions=mcp_sessions,
mcp_tools=mcp_tools,
mcp_prompts=mcp_prompts,
attachments_extension=attachments_extension,
context=context,
Expand Down
Loading

0 comments on commit 2561ea3

Please sign in to comment.