diff --git a/gai-backend/config_manager.py b/gai-backend/config_manager.py index f07b51c9b..d7aa9d608 100644 --- a/gai-backend/config_manager.py +++ b/gai-backend/config_manager.py @@ -3,6 +3,9 @@ import json import time import os +import logging + +logger = logging.getLogger("config_manager") class ConfigError(Exception): """Raised when config operations fail""" @@ -118,28 +121,54 @@ async def write_config(self, config: Dict[str, Any], force: bool = False): async def load_config(self, config_path: Optional[str] = None, force_reload: bool = False) -> Dict[str, Any]: try: + logger.debug(f"Loading config - path: {config_path}, force_reload: {force_reload}") + if config_path: + logger.debug(f"Loading config from file: {config_path}") config = await self.load_from_file(config_path) await self.write_config(config, force=True) + logger.debug("Loaded and wrote config from file") return config timestamp = await self.redis.get("config:last_update") + logger.debug(f"Redis config timestamp: {timestamp}, last_load_time: {self.last_load_time}") if not force_reload and timestamp and self.last_load_time >= float(timestamp): + logger.debug("Using cached config (no changes detected)") return self.current_config config_data = await self.redis.get("config:data") if not config_data: - raise ConfigError("No configuration found in Redis") + logger.warning("No configuration found in Redis") + # Provide a minimal default config instead of raising an error + default_config = { + "inference": { + "endpoints": {}, + "tools": { + "enabled": False, + "inject_defaults": False + } + } + } + logger.debug("Using minimal default configuration") + self.current_config = default_config + return default_config + logger.debug("Loading config from Redis") config = json.loads(config_data) config = self.process_config(config) + if "tools" in config.get("inference", {}): + tools_cfg = config["inference"]["tools"] + logger.debug(f"Tools config loaded - enabled: {tools_cfg.get('enabled')}, inject_defaults: {tools_cfg.get('inject_defaults')}") + self.current_config = config self.last_load_time = float(timestamp) if timestamp else time.time() + logger.debug("Successfully loaded config from Redis") return config except Exception as e: + logger.error(f"Failed to load config: {e}", exc_info=True) raise ConfigError(f"Failed to load config: {e}") async def check_for_updates(self) -> bool: diff --git a/gai-backend/inference_adapters.py b/gai-backend/inference_adapters.py index e7555687d..7237eb245 100644 --- a/gai-backend/inference_adapters.py +++ b/gai-backend/inference_adapters.py @@ -19,6 +19,8 @@ logger = logging.getLogger(__name__) class ModelAdapter: + logger = logging.getLogger(__name__) + @staticmethod def parse_response(api_type: str, response: Dict[str, Any], model: str, request_id: str) -> ChatCompletion: if api_type in ('openai', 'openrouter'): @@ -72,8 +74,8 @@ def parse_anthropic_response(response: Dict[str, Any], model: str, response_id: system_fingerprint=response.get('system_fingerprint') ) - @staticmethod - def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]: + @classmethod + def prepare_anthropic_request(cls, model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]: headers = { "Content-Type": "application/json", "x-api-key": api_key, @@ -129,6 +131,7 @@ def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, reques # Handle tools tools = request.get_effective_tools() if tools: + cls.logger.debug(f"Including {len(tools)} tools in Anthropic request") data["tools"] = [] for tool in tools: tool_def = { @@ -137,8 +140,9 @@ def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, reques "input_schema": tool.function.parameters } data["tools"].append(tool_def) - - # Handle tool choice + cls.logger.debug(f"Tools in Anthropic request: {json.dumps(data['tools'])}") + + # Handle tool choice - only when tools are present tool_choice = request.get_effective_tool_choice() if isinstance(tool_choice, str): data["tool_choice"] = "none" if tool_choice == "none" else {"type": "auto"} @@ -149,6 +153,8 @@ def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, reques } else: data["tool_choice"] = {"type": "auto"} + else: + cls.logger.debug("No tools to include in Anthropic request") # Copy other parameters if request.temperature is not None: @@ -188,8 +194,8 @@ def parse_openai_response(response: Dict[str, Any], model: str, response_id: str ) ) - @staticmethod - def prepare_openai_request(model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]: + @classmethod + def prepare_openai_request(cls, model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}" @@ -198,6 +204,15 @@ def prepare_openai_request(model_config: Dict[str, Any], api_key: str, request: data = request.dict(exclude_none=True, exclude={'request_id'}) data['model'] = model_config['id'] + # Log tools information + tools = request.get_effective_tools() + if tools: + cls.logger.debug(f"Including {len(tools)} tools in OpenAI request") + for tool in tools: + cls.logger.debug(f"Tool in OpenAI request: {tool.function.name}") + else: + cls.logger.debug("No tools to include in OpenAI request") + return data, headers @classmethod diff --git a/gai-backend/inference_api.py b/gai-backend/inference_api.py index 312edd18b..44821cdfd 100644 --- a/gai-backend/inference_api.py +++ b/gai-backend/inference_api.py @@ -98,17 +98,23 @@ async def chat_completion( @app.get("/v1/models") async def list_openai_models(): try: + logger.debug("Handling request to /v1/models") return await api.list_openai_models() except Exception as e: + logger.error(f"Error listing OpenAI models: {e}", exc_info=True) if isinstance(e, InferenceError): raise - raise HTTPException(status_code=500, detail=str(e)) + # Return empty list instead of error + return {"data": []} @app.get("/v1/inference/models") async def list_inference_models(): try: + logger.debug("Handling request to /v1/inference/models") return await api.list_models() except Exception as e: + logger.error(f"Error listing inference models: {e}", exc_info=True) if isinstance(e, InferenceError): raise - raise HTTPException(status_code=500, detail=str(e)) + # Return empty dict instead of error + return {} diff --git a/gai-backend/inference_core.py b/gai-backend/inference_core.py index a97e46540..6cb06738a 100644 --- a/gai-backend/inference_core.py +++ b/gai-backend/inference_core.py @@ -1,8 +1,12 @@ from redis.asyncio import Redis import json import requests +import copy from typing import Dict, Any, Tuple, AsyncGenerator, Union from datetime import datetime +import subprocess +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client from inference_models import ( ChatCompletionRequest, @@ -15,7 +19,8 @@ OpenAIModel, OpenAIModelList, Tool, - ToolChoice + ToolChoice, + FunctionDefinition ) from inference_errors import ( InferenceError, @@ -57,10 +62,252 @@ def __init__(self, redis: Redis): self.redis = redis self.config_manager = ConfigManager(redis) self.billing = StrictRedisBilling(redis) + self.mcp_session = None async def init(self): await self.billing.init() await self.config_manager.load_config() + + # Initialize tool system + from tool_registry import ToolRegistry + from tool_executor import ToolExecutor + + # Create tool registry and executor + self.tool_registry = ToolRegistry() + + # Set initial tool state + self.tools_initialized = False + self.tools_initialization_error = None + + # Get configuration + config = await self.config_manager.load_config() + + # Initialize basic components needed for models API to work + self.tool_executor = ToolExecutor(self.billing, config) + + # Start async initialization of tool servers (non-blocking) + import asyncio + logger.info("Starting asynchronous tool initialization") + asyncio.create_task(self._init_tools_async(config)) + + async def _init_tools_async(self, config): + """Initialize tool servers asynchronously to avoid blocking API startup""" + try: + # Set initialized to False at the start of initialization + self.tools_initialized = False + self.tools_initialization_error = None + + # Log key config information at debug level + tools_config = config.get('inference', {}).get('tools', {}) + tools_enabled = tools_config.get('enabled', False) + inject_defaults = tools_config.get('inject_defaults', False) + + logger.debug(f"Tools config: enabled={tools_enabled}, inject_defaults={inject_defaults}") + + # Connect to MCP servers for tools + if tools_enabled: + logger.debug("Initializing tools from configuration") + import asyncio + + # Create a list to hold all server connection tasks + server_tasks = [] + mcp_servers = tools_config.get('mcp_servers', {}) + + # Log the number of MCP servers to be initialized + logger.info(f"Found {len(mcp_servers)} MCP servers in configuration") + + # Start each MCP server connection in sequence (avoid race conditions) + # This is a key change from the previous approach to ensure reliable initialization + for server_id, server_config in mcp_servers.items(): + try: + # Get environment variables from parent process and config + import os + + # Start with parent environment + combined_env = dict(os.environ) + + # Add config environment variables (will override parent if same keys) + config_env = server_config.get('env', {}) + combined_env.update(config_env) + + # Ensure we have EXA_API_KEY if in parent environment + if 'EXA_API_KEY' in os.environ and ('EXA_API_KEY' not in config_env or + 'your' in config_env.get('EXA_API_KEY', '').lower()): + logger.info(f"Using EXA_API_KEY from parent environment for {server_id}") + combined_env['EXA_API_KEY'] = os.environ['EXA_API_KEY'] + + # Log only critical environment variables + logger.debug(f"Environment for {server_id} prepared with {len(combined_env)} variables") + + server_params = StdioServerParameters( + command=server_config.get('command', 'node'), + args=server_config.get('args', []), + env=combined_env + ) + + # Connect to this server directly (not as a task) + # This ensures servers are initialized one at a time, avoiding race conditions + logger.info(f"Connecting to MCP server: {server_id}") + connection_success = await self._connect_mcp_server(server_id, server_params) + + # Log connection status + if connection_success: + logger.info(f"Successfully connected to MCP server: {server_id}") + else: + logger.error(f"Failed to connect to MCP server: {server_id}") + + # Add a small delay between server initializations + await asyncio.sleep(1.0) + + except Exception as e: + logger.error(f"Failed to initialize MCP server {server_id}: {e}", exc_info=True) + + # Initialize tools from configuration with an increased timeout + try: + # Add timeout for tool initialization + async with asyncio.timeout(30): + logger.debug("Starting tool registry initialization from config") + await self.tool_registry.init_from_config(config) + logger.debug("Completed tool registry initialization") + except asyncio.TimeoutError: + logger.error("Timeout during tool registry initialization after 30 seconds") + self.tools_initialization_error = "Timeout during tool registry initialization" + except Exception as e: + logger.error(f"Error during tool registry initialization: {e}", exc_info=True) + self.tools_initialization_error = str(e) + + # Log registered tools + tools = self.tool_registry.list_tools() + logger.debug(f"Registered tools: {tools}") + + if not tools: + logger.warning("No tools were registered! Tool injection will not work.") + else: + logger.info(f"Successfully registered {len(tools)} tools: {', '.join(tools)}") + + # Mark tools as successfully initialized (even if there was an error) + # This allows the server to continue operating without tools if needed + self.tools_initialized = True + logger.info("Tool initialization completed") + else: + # Tools not enabled, mark as initialized to avoid waiting + self.tools_initialized = True + logger.info("Tools are disabled in configuration") + except Exception as e: + logger.error(f"Error during async tool initialization: {e}", exc_info=True) + self.tools_initialization_error = str(e) + # Mark as initialized but with an error + self.tools_initialized = True + + async def _connect_mcp_server(self, server_id, server_params): + """Connect to a single MCP server using direct subprocess communication""" + try: + import asyncio + import os + from contextlib import AsyncExitStack + + # Detailed environment logging for debugging + env = server_params.env or {} + has_api_key = False + + # Check for API key presence - don't show the actual key + for k, v in env.items(): + if k.upper() == 'EXA_API_KEY' and v and 'your' not in v.lower(): + has_api_key = True + logger.debug(f"Found EXA_API_KEY in environment for {server_id}") + + if not has_api_key and server_id == 'exa-mcp': + logger.warning(f"No valid EXA_API_KEY found for {server_id} - connection will likely fail") + + # Check if it's in the parent environment + if 'EXA_API_KEY' in os.environ and 'your' not in os.environ['EXA_API_KEY'].lower(): + logger.info(f"Found EXA_API_KEY in parent environment, will use that for {server_id}") + has_api_key = True + + # Make sure the environment is updated + if isinstance(server_params.env, dict): + server_params.env['EXA_API_KEY'] = os.environ['EXA_API_KEY'] + else: + server_params.env = {'EXA_API_KEY': os.environ['EXA_API_KEY']} + + # Use the direct subprocess approach + try: + # Get command and args + command = server_params.command + args = server_params.args + env_dict = dict(server_params.env) if server_params.env else dict(os.environ) + + logger.info(f"Starting MCP server {server_id} with command: {command} {' '.join(args)}") + + # Import our direct subprocess implementation + from mcp_wrapper import create_mcp_server, MCPSession + + # Use the context manager to create and manage server process + stack = AsyncExitStack() + process = await stack.enter_async_context(create_mcp_server(command, args, env=env_dict)) + + # Create and initialize session + logger.info(f"Creating MCP session for {server_id}") + session = await MCPSession.create(process) + logger.info(f"Session created and initialized for {server_id}") + + # Test connection by listing tools + logger.info(f"Listing tools for {server_id}") + tools_response = await session.list_tools() + + if tools_response: + # Handle different return types from the MCPSession.list_tools method + if hasattr(tools_response, 'tools'): + # New style response (ToolsResponse object) + tools_list = tools_response.tools + tool_count = len(tools_list) + + # Log tool information + logger.info(f"MCP server {server_id} has {tool_count} tools available") + for tool in tools_list: + # Handle both attribute and dictionary style access + if hasattr(tool, 'name') and hasattr(tool, 'description'): + logger.info(f"Tool: {tool.name} - {tool.description}") + else: + logger.info(f"Tool: {tool.get('name')} - {tool.get('description')}") + + elif isinstance(tools_response, list): + # Old style response (list of dictionaries) + tool_count = len(tools_response) + + # Log tool information + logger.info(f"MCP server {server_id} has {tool_count} tools available") + for tool in tools_response: + logger.info(f"Tool: {tool.get('name')} - {tool.get('description')}") + + else: + # Unexpected response type + logger.warning(f"Unexpected tools response type: {type(tools_response)}") + await stack.aclose() + return False + + # Register the session + self.tool_registry.register_mcp_session(server_id, session) + logger.info(f"Successfully connected and registered MCP server: {server_id}") + + # Store the exit stack for proper cleanup later + if not hasattr(self, '_mcp_stacks'): + self._mcp_stacks = {} + self._mcp_stacks[server_id] = stack + + return True + else: + logger.warning(f"No tools found for MCP server {server_id}") + await stack.aclose() + return False + + except Exception as e: + logger.error(f"Error connecting to MCP server {server_id}: {e}", exc_info=True) + return False + + except Exception as e: + logger.error(f"Failed to connect to MCP server {server_id}: {e}", exc_info=True) + return False async def validate_session(self, session_id: str) -> None: if not session_id: @@ -138,6 +385,18 @@ def query_backend(self, endpoint_config: Dict[str, Any], model_config: Dict[str, try: data, headers = ModelAdapter.prepare_request(endpoint_config, model_config, request) + # Log the prepared request data for debugging (redact API keys) + sanitized_headers = {k: v if k.lower() != 'authorization' and 'key' not in k.lower() else '[REDACTED]' + for k, v in headers.items()} + logger.debug(f"Backend request - URL: {endpoint_config['url']}") + logger.debug(f"Backend request - Headers: {sanitized_headers}") + + # Check for tools in prepared request data + if 'tools' in data: + logger.debug(f"Backend request contains {len(data['tools'])} tools") + elif request.tools: + logger.warning("Request had tools but they were not included in backend request!") + response = requests.post( endpoint_config['url'], headers=headers, @@ -167,34 +426,72 @@ def query_backend(self, endpoint_config: Dict[str, Any], model_config: Dict[str, raise BackendServiceError() async def list_models(self) -> Dict[str, ModelInfo]: - config = await self.config_manager.load_config() - models = {} + try: + config = await self.config_manager.load_config() + models = {} - for endpoint_id, endpoint in config['inference']['endpoints'].items(): - for model in endpoint['models']: - models[model['id']] = ModelInfo( - id=model['id'], - name=model.get('display_name', model['id']), - api_type=endpoint['api_type'], - endpoint=endpoint_id - ) + # Safely access endpoints if they exist + endpoints = config.get('inference', {}).get('endpoints', {}) + if not endpoints: + logger.warning("No endpoints found in configuration") + return {} + + for endpoint_id, endpoint in endpoints.items(): + # Make sure the endpoint has the required fields + if 'api_type' not in endpoint or 'models' not in endpoint: + logger.warning(f"Endpoint {endpoint_id} is missing required fields") + continue + + for model in endpoint['models']: + if 'id' not in model: + logger.warning(f"Model in endpoint {endpoint_id} is missing id") + continue + + models[model['id']] = ModelInfo( + id=model['id'], + name=model.get('display_name', model['id']), + api_type=endpoint['api_type'], + endpoint=endpoint_id + ) - return models + return models + except Exception as e: + logger.error(f"Error fetching models list: {e}", exc_info=True) + return {} # Return empty dict instead of failing async def list_openai_models(self) -> OpenAIModelList: - config = await self.config_manager.load_config() - models = [] - created = int(datetime.now().timestamp()) + try: + config = await self.config_manager.load_config() + models = [] + created = int(datetime.now().timestamp()) - for endpoint_id, endpoint in config['inference']['endpoints'].items(): - for model in endpoint['models']: - models.append(OpenAIModel( - id=model['id'], - created=model.get('created', created), - owned_by=endpoint.get('provider', 'orchid-labs') - )) + # Safely access endpoints if they exist + endpoints = config.get('inference', {}).get('endpoints', {}) + if not endpoints: + logger.warning("No endpoints found in configuration") + return OpenAIModelList(data=[]) - return OpenAIModelList(data=models) + for endpoint_id, endpoint in endpoints.items(): + # Make sure the endpoint has models + if 'models' not in endpoint: + logger.warning(f"Endpoint {endpoint_id} has no models") + continue + + for model in endpoint['models']: + if 'id' not in model: + logger.warning(f"Model in endpoint {endpoint_id} is missing id") + continue + + models.append(OpenAIModel( + id=model['id'], + created=model.get('created', created), + owned_by=endpoint.get('provider', 'orchid-labs') + )) + + return OpenAIModelList(data=models) + except Exception as e: + logger.error(f"Error fetching OpenAI models list: {e}", exc_info=True) + return OpenAIModelList(data=[]) # Return empty list instead of failing async def create_stream_chunks(self, completion: ChatCompletion) -> AsyncGenerator[str, None]: first_chunk = ChatCompletionChunk( @@ -228,18 +525,26 @@ async def create_stream_chunks(self, completion: ChatCompletion) -> AsyncGenerat yield f"data: {json.dumps(tool_chunk.dict(exclude_none=True))}\n\n" if message.content: - content_chunk = ChatCompletionChunk( - id=completion.id, - model=completion.model, - choices=[ - ChatChoice( - index=choice.index, - delta={"content": message.content}, - finish_reason=None - ) - ] - ) - yield f"data: {json.dumps(content_chunk.dict(exclude_none=True))}\n\n" + # For content that was concatenated from multiple completions, + # we split by newline and stream each part separately + content_parts = message.content.split("\n") + + for part in content_parts: + if not part: # Skip empty parts + continue + + content_chunk = ChatCompletionChunk( + id=completion.id, + model=completion.model, + choices=[ + ChatChoice( + index=choice.index, + delta={"content": part + "\n"}, # Add newline to maintain formatting + finish_reason=None + ) + ] + ) + yield f"data: {json.dumps(content_chunk.dict(exclude_none=True))}\n\n" final_chunk = ChatCompletionChunk( id=completion.id, @@ -273,62 +578,273 @@ async def stream_inference(self, request: ChatCompletionRequest, session_id: str async def handle_inference(self, request: ChatCompletionRequest, session_id: str) -> ChatCompletion: try: - logger.debug(f"Starting inference handling for session {session_id}") - endpoint_config, model_config = await self.get_model_config(request.model) - logger.debug(f"Retrieved model config for {request.model}") - - tools = request.get_effective_tools() - if tools and len(tools) > 128: - raise ValidationError("Maximum of 128 tools allowed") - - input_tokens = self.count_input_tokens(request) - max_output_tokens = request.max_tokens or model_config.get('params', {}).get( - 'max_tokens', - endpoint_config.get('params', {}).get('max_tokens', 4096) - ) + # Add retry mechanism for tool initialization + max_retries = 1 + retry_count = 0 + retry_delay = 0.5 # seconds - max_cost = self.calculate_cost( - model_config['pricing'], - input_tokens, - max_output_tokens - ) - - logger.debug(f"Calculated max cost: {max_cost}") - - balance = await self.billing.balance(session_id) - if balance < max_cost: - await self.redis.publish( - f"billing:balance:updates:{session_id}", - str(balance) - ) - raise InsufficientBalanceError() + while retry_count <= max_retries: + try: + logger.debug(f"Starting inference handling for session {session_id}") + if retry_count > 0: + logger.info(f"Retry attempt {retry_count} for request {request.request_id}") + endpoint_config, model_config = await self.get_model_config(request.model) + + # Inject tools if configured and no tools are specified + config = await self.config_manager.load_config() + logger.debug(f"Handle inference - Tools config: {json.dumps(config.get('inference', {}).get('tools', {}))}") + + tools_config = config.get('inference', {}).get('tools', {}) + tools_enabled = tools_config.get('enabled', False) + inject_defaults = tools_config.get('inject_defaults', False) + + # Check if the request already has tools + if request.tools: + logger.debug(f"Request already has {len(request.tools)} tools defined, not injecting") + elif tools_enabled and inject_defaults: + # Check if tools are initialized + if not self.tools_initialized: + logger.warning("Tools are still initializing, skipping tool injection") + retry_count += 1 + if retry_count <= max_retries: + logger.info(f"Waiting {retry_delay}s before retry {retry_count}/{max_retries}") + import asyncio + await asyncio.sleep(retry_delay) + continue + else: + logger.warning("Max retries reached waiting for tool initialization") + + # Check if there was an error during initialization + if self.tools_initialization_error: + logger.warning(f"Tool initialization had errors: {self.tools_initialization_error}") + + # Get available tools from registry + available_tools = self.tool_registry.get_available_tools() + logger.debug(f"Available tools count: {len(available_tools)}") + + # Log details about available tools + for tool in available_tools: + logger.debug(f"Tool to inject: {tool.function.name} - {tool.function.description}") + + if available_tools: + logger.debug("Injecting tools into request") + request.tools = available_tools + logger.debug(f"After injection, request has {len(request.tools)} tools") + else: + logger.warning("No tools available to inject!") + else: + logger.debug("Tool injection is disabled or not configured") + + # Calculate and validate costs + input_tokens = self.count_input_tokens(request) + max_output_tokens = request.max_tokens or model_config.get('params', {}).get( + 'max_tokens', + endpoint_config.get('params', {}).get('max_tokens', 4096) + ) + + max_cost = self.calculate_cost( + model_config['pricing'], + input_tokens, + max_output_tokens + ) + + logger.debug(f"Calculated max cost: {max_cost}") + + balance = await self.billing.balance(session_id) + if balance < max_cost: + await self.redis.publish( + f"billing:balance:updates:{session_id}", + str(balance) + ) + raise InsufficientBalanceError() + + await self.billing.debit(session_id, amount=max_cost) + logger.debug(f"Debited {max_cost} from session {session_id}") + + # Log the request before sending for debugging + if request.tools: + logger.debug(f"Sending request to backend with {len(request.tools)} tools") + for tool in request.tools: + logger.debug(f"Tool in request: {tool.function.name}") + + # Get initial completion which may include tool calls + initial_completion = self.query_backend(endpoint_config, model_config, request) + + # Check if the response includes tool calls + if initial_completion.choices and initial_completion.choices[0].message.tool_calls: + # Save the initial content from the model (may be None) + initial_content = initial_completion.choices[0].message.content + accumulated_content = [] if initial_content is None else [initial_content] + + # Clone the request for our follow-up + follow_up_request = copy.deepcopy(request) + + # Get the tool calls + tool_calls = initial_completion.choices[0].message.tool_calls + + # Add the assistant's initial response with tool calls + follow_up_request.messages.append(Message( + role="assistant", + content=initial_content, + tool_calls=tool_calls + )) + + # Process tool calls with the tool executor + execution_context = { + 'request_id': request.request_id, + 'model': request.model, + 'endpoint': endpoint_config['api_type'] + } + + # Execute all tools + tool_responses = await self.tool_executor.execute_all_tools( + session_id=session_id, + tool_calls=tool_calls, + context=execution_context + ) + + # Add tool responses to the follow-up request + for response in tool_responses: + follow_up_request.messages.append(response) + + # Track completion costs for final calculation + total_prompt_tokens = initial_completion.usage.prompt_tokens + total_completion_tokens = initial_completion.usage.completion_tokens + + # Process recursive tool calls with a loop - no fixed limit + # as long as the user keeps paying, we'll keep processing + iteration = 0 + reached_limit = False + max_consecutive_errors = 3 # Safety limit for errors + consecutive_errors = 0 + logger.info(f"Starting recursive tool execution for request {request.request_id}") + + # Continue processing as long as errors don't exceed the threshold + while consecutive_errors < max_consecutive_errors: + try: + # Get next completion with tool results + next_completion = self.query_backend(endpoint_config, model_config, follow_up_request) + + # Add costs + total_prompt_tokens += next_completion.usage.prompt_tokens + total_completion_tokens += next_completion.usage.completion_tokens + + # Reset error counter on successful completion + consecutive_errors = 0 + except Exception as e: + # Increment error counter and log the error + consecutive_errors += 1 + logger.warning(f"Error during tool iteration {iteration+1} (error {consecutive_errors}/{max_consecutive_errors}): {e}") + + if consecutive_errors < max_consecutive_errors: + continue + else: + # We'll handle this outside the loop + break + + # Check if this response includes more tool calls + if (next_completion.choices and next_completion.choices[0].message.tool_calls): + # Track the next tool calls + # Try to detect if we're in an infinite loop by tracking repeated identical tool calls + # Save any content from this response + content = next_completion.choices[0].message.content + if content: + accumulated_content.append(content) + + # Get new tool calls + tool_calls = next_completion.choices[0].message.tool_calls + + # Add the assistant's response to the conversation + follow_up_request.messages.append(Message( + role="assistant", + content=content, + tool_calls=tool_calls + )) + + # Execute the new tools + tool_responses = await self.tool_executor.execute_all_tools( + session_id=session_id, + tool_calls=tool_calls, + context=execution_context + ) + + # Add new tool responses + for response in tool_responses: + follow_up_request.messages.append(response) + + # Continue the loop + iteration += 1 + logger.info(f"Completed tool iteration {iteration} for request {request.request_id}") + else: + # No more tool calls, we're done + # Add the final content and break the loop + if next_completion.choices[0].message.content: + accumulated_content.append(next_completion.choices[0].message.content) + logger.info(f"Finished tool iterations (no more tool calls) after {iteration} iterations") + break + + # Handle consecutive errors - if we reached error limit + if consecutive_errors >= max_consecutive_errors: + logger.warning(f"Reached maximum consecutive errors ({max_consecutive_errors}) for request {request.request_id}") + # Add a user message asking for a summary of what was retrieved + follow_up_request.messages.append(Message( + role="user", + content="There appears to be an issue with continued tool execution. Please summarize what information you've retrieved so far." + )) + + # Get the final summary + final_summary = self.query_backend(endpoint_config, model_config, follow_up_request) + + # Add costs for the summary + total_prompt_tokens += final_summary.usage.prompt_tokens + total_completion_tokens += final_summary.usage.completion_tokens + + # Add the summary to accumulated content + if final_summary.choices and final_summary.choices[0].message.content: + logger.info(f"Received error summary of length {len(final_summary.choices[0].message.content)} for request {request.request_id}") + accumulated_content.append(final_summary.choices[0].message.content) + else: + logger.warning(f"Error summary has no content for request {request.request_id}") + + # Use this as our final completion + final_completion = final_summary + else: + logger.info(f"Tool execution completed normally after {iteration} iterations for request {request.request_id}") + # Create the final completion from the last completion + final_completion = next_completion + + # Join all accumulated content with newlines + final_content = "\n".join([c for c in accumulated_content if c]) + + # Update the final completion's content + if final_completion.choices and final_completion.choices[0].message: + final_completion.choices[0].message.content = final_content + + # Update usage statistics + final_completion.usage.prompt_tokens = total_prompt_tokens + final_completion.usage.completion_tokens = total_completion_tokens + final_completion.usage.total_tokens = total_prompt_tokens + total_completion_tokens + + return final_completion + + # If no tool calls, return the initial completion + return initial_completion - await self.billing.debit(session_id, amount=max_cost) - logger.debug(f"Debited {max_cost} from session {session_id}") + except Exception as e: + retry_count += 1 + if retry_count <= max_retries: + logger.warning(f"Error during inference (attempt {retry_count}/{max_retries+1}): {e}") + import asyncio + await asyncio.sleep(retry_delay) + continue + + logger.error(f"Error during inference: {str(e)}", exc_info=True) + if 'max_cost' in locals(): + await self.billing.credit(session_id, amount=max_cost) + raise - try: - logger.debug("Querying backend API") - result = self.query_backend(endpoint_config, model_config, request) - logger.debug(f"Received backend response: {result.dict(exclude_none=True)}") - - actual_cost = self.calculate_cost( - model_config['pricing'], - result.usage.prompt_tokens, - result.usage.completion_tokens - ) - - if actual_cost < max_cost: - refund = max_cost - actual_cost - await self.billing.credit(session_id, amount=refund) - logger.debug(f"Credited refund of {refund} to session {session_id}") - - return result - - except Exception as e: - logger.error(f"Error during backend query: {str(e)}", exc_info=True) - await self.billing.credit(session_id, amount=max_cost) - raise - + raise Exception("Unexpected end of retry loop in handle_inference") + except InferenceError: raise except Exception as e: diff --git a/gai-backend/inference_logging.py b/gai-backend/inference_logging.py index ebc9f4d55..df12192d1 100644 --- a/gai-backend/inference_logging.py +++ b/gai-backend/inference_logging.py @@ -25,7 +25,7 @@ "console": { "class": "logging.StreamHandler", "formatter": "colored", - "level": "INFO" + "level": "DEBUG" } }, "loggers": { @@ -41,6 +41,11 @@ "handlers": ["console"], "level": os.getenv("ORCHID_GENAI_INF_LOGLVL", "INFO"), "propagate": False + }, + "config_manager": { + "handlers": ["console"], + "level": os.getenv("ORCHID_GENAI_INF_LOGLVL", "INFO"), + "propagate": False } } } diff --git a/gai-backend/inference_models.py b/gai-backend/inference_models.py index b6118503e..9aca17e11 100644 --- a/gai-backend/inference_models.py +++ b/gai-backend/inference_models.py @@ -5,6 +5,7 @@ import re def validate_tool_name(v: str) -> str: + # Validate tool name according to Claude API requirements if not re.match(r'^[a-zA-Z0-9_-]{1,64}$', v): raise ValueError("Tool name must match regex ^[a-zA-Z0-9_-]{1,64}$") return v @@ -148,4 +149,4 @@ def __init__(self, status_code: int, detail: str): self.detail = detail class PricingError(Exception): - pass + pass \ No newline at end of file diff --git a/gai-backend/mcp_wrapper.py b/gai-backend/mcp_wrapper.py new file mode 100644 index 000000000..bb313b3c5 --- /dev/null +++ b/gai-backend/mcp_wrapper.py @@ -0,0 +1,321 @@ +""" +Simplified MCP implementation for direct subprocess communication. +This bypasses the SDK and uses direct JSON-RPC communication. +""" +import json +import logging +import asyncio +import subprocess +from typing import Dict, Any, Optional, List +from contextlib import asynccontextmanager + +logger = logging.getLogger(__name__) + +@asynccontextmanager +async def create_mcp_server(command, args, env=None): + """ + Start an MCP server and manage its lifecycle. + + Args: + command: Server command (e.g., 'node') + args: Command arguments + env: Environment variables + + Yields: + A connected server process + """ + logger.info(f"Starting MCP server process: {command} {' '.join(args)}") + + # Start the server process + process = subprocess.Popen( + [command] + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=env + ) + + try: + # Wait for startup message with a short timeout using both stdout and stderr + logger.info("Waiting for server startup") + import select + + # Set a timeout for the startup message + timeout_seconds = 5 + start_time = asyncio.get_event_loop().time() + startup_msg = None + + # Non-blocking check for output + readable, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) + + # Check stderr for any messages + if process.stderr in readable: + stderr_msg = process.stderr.readline().strip() + if stderr_msg: + # Look for startup indicators in the stderr message + if "MCP server running" in stderr_msg or "running on stdio" in stderr_msg: + logger.info(f"Server startup message (from stderr): {stderr_msg}") + startup_msg = stderr_msg + else: + logger.warning(f"Server stderr: {stderr_msg}") + + # Also check stdout for startup messages + if process.stdout in readable: + stdout_msg = process.stdout.readline().strip() + if stdout_msg: + logger.info(f"Server startup message (from stdout): {stdout_msg}") + startup_msg = stdout_msg + + # If we got a message, note that startup was successful + if startup_msg: + logger.info("Server startup successful") + else: + logger.info("No explicit startup message detected, but proceeding anyway") + + # Yield the process for use + yield process + finally: + # Clean up the process + logger.info("Terminating MCP server process") + process.terminate() + try: + process.wait(timeout=2) + except subprocess.TimeoutExpired: + logger.warning("MCP server process did not terminate gracefully, killing") + process.kill() + +class MCPSession: + """ + A direct implementation of MCP using subprocess. + This works directly with the JSON-RPC protocol without the SDK. + """ + + def __init__(self, process): + """Initialize with a running MCP server process""" + self.process = process + self.request_id = 1 + + @classmethod + async def create(cls, process): + """ + Create and initialize an MCP session + + Args: + process: A running MCP server process + + Returns: + An initialized MCPSession + """ + try: + logger.debug("Creating direct MCP session") + session = cls(process) + + # Initialize the session + await session._initialize() + + return session + + except Exception as e: + logger.error(f"Error initializing MCP session: {e}", exc_info=True) + raise + + async def _initialize(self): + """Initialize the MCP session with direct JSON-RPC""" + logger.debug("Initializing MCP session") + + # Send initialization request + init_request = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "clientInfo": { + "name": "orchid-client", + "version": "1.0.0" + }, + "capabilities": { + "tools": {} + } + }, + "id": self.request_id + } + self.request_id += 1 + + logger.debug("Sending initialize request") + self.process.stdin.write(json.dumps(init_request) + "\n") + self.process.stdin.flush() + + # Read response + logger.debug("Waiting for initialize response") + init_response = self.process.stdout.readline() + logger.debug(f"Got initialize response: {init_response}") + + # Send initialized notification + init_notification = { + "jsonrpc": "2.0", + "method": "notifications/initialized" + } + logger.debug("Sending initialized notification") + self.process.stdin.write(json.dumps(init_notification) + "\n") + self.process.stdin.flush() + + # Give server a moment to process + await asyncio.sleep(0.5) + + logger.debug("MCP session initialized successfully") + + async def list_tools(self) -> List[Dict[str, Any]]: + """ + List available tools from the MCP server + + Returns: + List of tool definitions + """ + try: + logger.debug("Listing MCP tools") + + # Create list request + list_request = { + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": self.request_id + } + self.request_id += 1 + + # Send request + logger.debug("Sending tools/list request") + self.process.stdin.write(json.dumps(list_request) + "\n") + self.process.stdin.flush() + + # Read response + logger.debug("Waiting for tools/list response") + list_response = self.process.stdout.readline() + logger.debug(f"Got tools/list response: {list_response}") + + # Parse response + try: + if not list_response.strip(): + logger.error("Empty response from MCP server when listing tools") + return [] + + response_data = json.loads(list_response) + if "error" in response_data: + logger.error(f"Error listing tools: {response_data['error']}") + return [] + + if "result" in response_data and "tools" in response_data["result"]: + tools = response_data["result"]["tools"] + tool_names = [tool.get('name', 'unnamed') for tool in tools] + logger.info(f"Found {len(tools)} tools: {tool_names}") + + # Convert to the expected structure for tool_registry + # The tool registry expects objects with a name attribute, not dictionaries + class ToolObj: + def __init__(self, name, description): + self.name = name + self.description = description + + # Create proper tool objects + tool_objects = [] + for tool in tools: + tool_obj = ToolObj( + name=tool.get('name', 'unnamed'), + description=tool.get('description', '') + ) + tool_objects.append(tool_obj) + + # We return this for tool_registry compatibility + class ToolsResponse: + def __init__(self, tools): + self.tools = tools + + return ToolsResponse(tool_objects) + + logger.warning("No tools found in MCP server response") + return [] + + except json.JSONDecodeError: + logger.error("Could not parse tools response as JSON") + return [] + + except Exception as e: + logger.error(f"Error listing MCP tools: {e}", exc_info=True) + return [] + + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + Call a tool with the given arguments + + Args: + name: Name of the tool to call + arguments: Arguments to pass to the tool + + Returns: + Tool execution result + """ + try: + # Log tool arguments at a higher level for debugging + logger.info(f"Tool call to {name} with arguments: {json.dumps(arguments)}") + + logger.debug(f"Calling MCP tool {name} with arguments: {arguments}") + + # Create call request + call_request = { + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": name, + "arguments": arguments + }, + "id": self.request_id + } + self.request_id += 1 + + # Send request + logger.debug("Sending tools/call request") + self.process.stdin.write(json.dumps(call_request) + "\n") + self.process.stdin.flush() + + # Read response + logger.debug("Waiting for tools/call response") + call_response = self.process.stdout.readline() + + # Truncate long responses in logs + log_response = call_response[:1000] + "..." if len(call_response) > 1000 else call_response + logger.debug(f"Got tools/call response (truncated): {log_response}") + + # Parse response + try: + response_data = json.loads(call_response) + if "error" in response_data: + error_msg = response_data.get("error", {}).get("message", "Unknown error") + logger.error(f"Error calling tool {name}: {error_msg}") + return {"content": [{"type": "text", "text": f"Error: {error_msg}"}]} + + if "result" in response_data: + result = response_data["result"] + + # Check for any content truncation and log it + if "content" in result: + for content in result["content"]: + if content.get("type") == "text": + text = content.get("text", "") + + # Look for any truncation indicators + if "Content truncated" in text: + logger.info(f"Detected content truncation in tool result for {name}") + + return result + + logger.warning(f"No result found in tool call response for {name}") + return {"content": [{"type": "text", "text": "No result returned from tool"}]} + + except json.JSONDecodeError: + logger.error("Could not parse tool call response as JSON") + return {"content": [{"type": "text", "text": "Error: Invalid response format"}]} + + except Exception as e: + logger.error(f"Error calling MCP tool {name}: {e}", exc_info=True) + return {"content": [{"type": "text", "text": f"Error: {str(e)}"}]} \ No newline at end of file diff --git a/gai-backend/tool_config_example.json b/gai-backend/tool_config_example.json new file mode 100644 index 000000000..1687e7bc6 --- /dev/null +++ b/gai-backend/tool_config_example.json @@ -0,0 +1,184 @@ +{ + "inference": { + "api_url": "https://api.example.com/v1", + "endpoints": { + "anthropic": { + "api_type": "anthropic", + "url": "https://api.anthropic.com/v1/messages", + "api_key": "YOUR_ANTHROPIC_API_KEY", + "models": [ + { + "id": "claude-3-opus-20240229", + "display_name": "Claude 3 Opus", + "pricing": { + "type": "fixed", + "input_price": 15000, + "output_price": 75000 + } + }, + { + "id": "claude-3-sonnet-20240229", + "display_name": "Claude 3 Sonnet", + "pricing": { + "type": "fixed", + "input_price": 3000, + "output_price": 15000 + } + } + ] + } + }, + "tools": { + "enabled": true, + "inject_defaults": true, + "default_timeout_ms": 5000, + "refund_on_error": true, + "refund_on_timeout": true, + "mcp_servers": { + "exa-mcp": { + "command": "node", + "args": ["/opt/homebrew/lib/node_modules/exa-mcp-server/build/index.js"], + "env": {"EXA_API_KEY": "your-exa-api-key-here"} + }, + "weather-mcp": { + "command": "node", + "args": ["/opt/homebrew/lib/node_modules/weather-mcp/index.js"], + "env": {"WEATHER_API_KEY": "your-weather-api-key-here"} + }, + "database-mcp": { + "command": "python", + "args": ["-m", "database_mcp.server"], + "env": {"DB_CONNECTION_STRING": "postgresql://user:pass@localhost/db"} + } + }, + "registry": { + "web_search": { + "enabled": true, + "type": "mcp", + "server": "exa-mcp", + "billing_type": "web_search", + "description": "Search the web for current information on a given query.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query. Be specific and concise." + } + }, + "required": ["query"] + } + }, + "image_search": { + "enabled": true, + "type": "mcp", + "server": "exa-mcp", + "billing_type": "web_search", + "description": "Search for images online.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The image search query." + }, + "count": { + "type": "integer", + "description": "Number of images to return (max 10).", + "default": 3 + } + }, + "required": ["query"] + } + }, + "unit_converter": { + "enabled": true, + "type": "python", + "module": "tool_unit_converter", + "billing_type": "function_call", + "description": "Convert values between different units of measurement. Supports length, mass, volume, temperature, area, and time conversions.", + "parameters": { + "type": "object", + "properties": { + "value": { + "type": "number", + "description": "The numeric value to convert" + }, + "from_unit": { + "type": "string", + "description": "The source unit (e.g., 'm', 'kg', 'c')" + }, + "to_unit": { + "type": "string", + "description": "The target unit (e.g., 'ft', 'lb', 'f')" + }, + "unit_type": { + "type": "string", + "description": "The type of conversion to perform", + "enum": ["length", "mass", "volume", "temperature", "area", "time"] + } + }, + "required": ["value", "from_unit", "to_unit", "unit_type"] + } + }, + "weather": { + "enabled": true, + "type": "mcp", + "server": "weather-mcp", + "billing_type": "api_call", + "description": "Get current weather conditions for a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or geographic coordinates." + }, + "units": { + "type": "string", + "description": "Units of measurement (metric, imperial, standard).", + "enum": ["metric", "imperial", "standard"], + "default": "metric" + } + }, + "required": ["location"] + } + }, + "database_query": { + "enabled": true, + "type": "mcp", + "server": "database-mcp", + "billing_type": "data_query", + "description": "Query the application database (read-only).", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "SQL query to execute. SELECT statements only." + }, + "limit": { + "type": "integer", + "description": "Maximum number of rows to return.", + "default": 100 + } + }, + "required": ["query"] + } + } + } + } + }, + "billing": { + "prices": { + "invoice": 1.0, + "payment": 1.0, + "auth_token": 1.0, + "error": 1.0, + "function_call": 5.0, + "web_search": 10.0, + "api_call": 8.0, + "data_query": 15.0 + } + } +} \ No newline at end of file diff --git a/gai-backend/tool_executor.py b/gai-backend/tool_executor.py new file mode 100644 index 000000000..28e7df99a --- /dev/null +++ b/gai-backend/tool_executor.py @@ -0,0 +1,186 @@ +""" +Tool execution and billing integration. +""" +import logging +import uuid +import json +import asyncio +from typing import Dict, Any, List, Optional +from datetime import datetime + +from tool_registry import ToolRegistry, ToolExecutionError +from billing import StrictRedisBilling +from inference_models import Message + +logger = logging.getLogger(__name__) + +class ToolExecutor: + """ + Handles execution of tools with proper billing and tracking. + """ + + def __init__(self, billing: StrictRedisBilling, config: Dict[str, Any]): + self.billing = billing + self.config = config + self.registry = ToolRegistry() + self.default_timeout = config.get('inference', {}).get( + 'tools', {}).get('default_timeout_ms', 5000) / 1000 + + async def get_tool_price(self, tool_name: str) -> float: + """Get the price for executing a tool.""" + tool = self.registry.get_tool(tool_name) + if not tool: + return 0.0 + + billing_type = tool.config.get('billing_type', 'function_call') + billing_prices = self.config.get('billing', {}).get('prices', {}) + + # Get price from billing config + price = billing_prices.get(billing_type, 0.0) + return price + + async def execute_tool_with_billing( + self, + session_id: str, + tool_call_id: str, + tool_name: str, + arguments: Dict[str, Any], + context: Optional[Dict[str, Any]] = None + ) -> str: + """ + Execute a tool with proper billing and error handling. + + Args: + session_id: Client session ID for billing + tool_call_id: Unique ID for this tool call + tool_name: Name of the tool to execute + arguments: Arguments to pass to the tool + context: Optional execution context + + Returns: + The result of the tool execution as a string + """ + if not context: + context = {} + + # Add execution metadata to context + context.update({ + 'session_id': session_id, + 'tool_call_id': tool_call_id, + 'timestamp': datetime.now().isoformat(), + 'request_id': context.get('request_id', str(uuid.uuid4())) + }) + + # Get the tool (either by namespaced name or by original name) + tool = self.registry.get_tool(tool_name) + if not tool: + # Try lookup by original name for backward compatibility + original_name_matches = [t for t in self.registry._tools.values() + if getattr(t, 'original_name', None) == tool_name] + if original_name_matches: + tool = original_name_matches[0] + logger.info(f"Found tool by original name {tool_name}, using namespaced name {tool.name}") + tool_name = tool.name + + # Calculate and pre-authorize the cost + price = await self.get_tool_price(tool_name) + + if price > 0: + # Debit the account + await self.billing.debit(session_id, amount=price) + logger.info(f"Debited {price} from session {session_id} for tool {tool_name}") + + try: + # Execute the tool with timeout + result = await asyncio.wait_for( + self.registry.execute_tool(tool_name, arguments, context), + timeout=self.default_timeout + ) + return result + + except asyncio.TimeoutError: + display_name = tool.display_name if tool else tool_name + logger.error(f"Tool execution timed out: {display_name}") + # Refund on timeout if configured + if price > 0 and self.config.get('inference', {}).get('tools', {}).get('refund_on_timeout', True): + await self.billing.credit(session_id, amount=price) + logger.info(f"Refunded {price} to session {session_id} for timed out tool {display_name}") + return "Error: Tool execution timed out" + + except ToolExecutionError as e: + display_name = tool.display_name if tool else tool_name + logger.error(f"Tool execution error for {display_name}: {e}") + # Refund on error if configured + if price > 0 and self.config.get('inference', {}).get('tools', {}).get('refund_on_error', True): + await self.billing.credit(session_id, amount=price) + logger.info(f"Refunded {price} to session {session_id} for failed tool {display_name}") + return f"Error: {str(e)}" + + except Exception as e: + display_name = tool.display_name if tool else tool_name + logger.error(f"Unexpected error in tool execution for {display_name}: {e}") + # Refund on error if configured + if price > 0 and self.config.get('inference', {}).get('tools', {}).get('refund_on_error', True): + await self.billing.credit(session_id, amount=price) + logger.info(f"Refunded {price} to session {session_id} for failed tool {display_name}") + return f"Error: Unexpected error during tool execution" + + async def execute_all_tools( + self, + session_id: str, + tool_calls: List[Dict[str, Any]], + context: Optional[Dict[str, Any]] = None + ) -> List[Message]: + """ + Execute multiple tool calls and construct tool response messages. + + Args: + session_id: Client session ID for billing + tool_calls: List of tool calls from the model + context: Optional execution context + + Returns: + List of tool response messages + """ + if not context: + context = {} + + # Run all tool calls in parallel + tasks = [] + for tool_call in tool_calls: + if tool_call["type"] == "function": + function_call = tool_call["function"] + try: + arguments = json.loads(function_call.get("arguments", "{}")) + except json.JSONDecodeError: + arguments = {} + + tasks.append( + self.execute_tool_with_billing( + session_id=session_id, + tool_call_id=tool_call["id"], + tool_name=function_call["name"], + arguments=arguments, + context=context + ) + ) + + # Wait for all tasks to complete + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Construct tool response messages + messages = [] + for tool_call, result in zip(tool_calls, results): + # Handle exceptions + if isinstance(result, Exception): + content = f"Error: {str(result)}" + else: + content = result + + messages.append(Message( + role="tool", + content=content, + tool_call_id=tool_call["id"] + )) + + return messages \ No newline at end of file diff --git a/gai-backend/tool_registry.py b/gai-backend/tool_registry.py new file mode 100644 index 000000000..0f975f6a8 --- /dev/null +++ b/gai-backend/tool_registry.py @@ -0,0 +1,450 @@ +from typing import Dict, Any, List, Optional, Callable, Awaitable, Union, Type +import importlib +import inspect +import json +import logging +from mcp import ClientSession +from inference_models import FunctionDefinition, Tool + +logger = logging.getLogger(__name__) + +class ToolExecutionError(Exception): + """Error occurred during tool execution""" + pass + +class ToolDefinitionError(Exception): + """Error in tool definition""" + pass + +class BaseTool: + """Base class for all tools""" + + def __init__(self, name: str, config: Dict[str, Any]): + self.name = name + self.config = config + self.original_name = getattr(self, 'original_name', name) + + def get_definition(self) -> FunctionDefinition: + """Return the tool definition to be exposed to the LLM""" + raise NotImplementedError() + + async def execute(self, arguments: Dict[str, Any], context: Dict[str, Any]) -> str: + """Execute the tool with the given arguments""" + raise NotImplementedError() + + @property + def is_available(self) -> bool: + """Check if the tool is currently available""" + return True + + @property + def display_name(self) -> str: + """ + Return the preferred name to be displayed in logs and debugging. + This helps to identify the tool with its namespaced identifier + and original name (if different). + """ + if self.original_name != self.name: + return f"{self.name} (original: {self.original_name})" + return self.name + +class PythonTool(BaseTool): + """Tool implemented as a Python function""" + + def __init__(self, name: str, config: Dict[str, Any], func: Callable): + # Generate namespaced name for internal use: python____ + module_path = config.get("module", "unknown") + module_name = module_path.split(".")[-1] # Use last part of module path + namespaced_name = f"python__{module_name}__{name}" + + # Store original name for debugging/logging + self.original_name = name + + super().__init__(namespaced_name, config) + self.func = func + self.description = config.get("description") or inspect.getdoc(func) + self.parameters = config.get("parameters") + + def get_definition(self) -> FunctionDefinition: + return FunctionDefinition( + name=self.name, + description=self.description, + parameters=self.parameters or { + "type": "object", + "properties": {}, + "required": [] + } + ) + + async def execute(self, arguments: Dict[str, Any], context: Dict[str, Any]) -> str: + try: + result = self.func(arguments, context) + if inspect.isawaitable(result): + result = await result + return str(result) + except Exception as e: + logger.error(f"Error executing Python tool {self.name} (original: {self.original_name}): {str(e)}") + raise ToolExecutionError(f"Tool execution failed: {str(e)}") + +class MCPTool(BaseTool): + """Tool implemented via MCP protocol""" + + def __init__(self, name: str, config: Dict[str, Any], session): + """ + Initialize an MCP tool + + Args: + name: Tool name + config: Tool configuration + session: MCPSession or ClientSession instance + """ + # Generate namespaced name for internal use: mcp____ + server = config.get("server", "unknown") + namespaced_name = f"mcp__{server}__{name}" + + # Store original name for debugging/logging + self.original_name = name + + super().__init__(namespaced_name, config) + self.session = session + self.description = config.get("description", f"Execute the {name} tool") + self.parameters = config.get("parameters", { + "type": "object", + "properties": {}, + "required": [] + }) + + def get_definition(self) -> FunctionDefinition: + return FunctionDefinition( + name=self.name, + description=self.description, + parameters=self.parameters + ) + + async def execute(self, arguments: Dict[str, Any], context: Dict[str, Any]) -> str: + try: + # Check if we have our wrapper or direct ClientSession + logger.debug(f"Executing MCP tool {self.name} (original: {self.original_name}) with arguments: {arguments}") + + # For MCP tool calls, we need to use the original tool name when calling the MCP server + tool_name = self.original_name + + # Handle both MCPSession and ClientSession + from mcp_wrapper import MCPSession + if isinstance(self.session, MCPSession): + # Use our wrapper + result = await self.session.call_tool(tool_name, arguments) + + # Extract text from content array + if result and "content" in result: + text_contents = [] + for content in result["content"]: + if content.get("type") == "text": + text_contents.append(content.get("text", "")) + return "\n".join(text_contents) + elif hasattr(self.session, 'tools') and hasattr(self.session.tools, 'call'): + # Use standard ClientSession + response = await self.session.tools.call(tool_name, arguments) + + if response and hasattr(response, 'result') and hasattr(response.result, 'content'): + # Extract text content from MCP response + text_contents = [] + for content in response.result.content: + if hasattr(content, 'type') and content.type == "text" and hasattr(content, 'text'): + text_contents.append(content.text) + return "\n".join(text_contents) + elif hasattr(self.session, 'call_tool'): + # Direct subprocess wrapper + result = await self.session.call_tool(tool_name, arguments) + + # Extract text from content array + if result and "content" in result: + text_contents = [] + for content in result["content"]: + if content.get("type") == "text": + text_contents.append(content.get("text", "")) + return "\n".join(text_contents) + + return "No results found" + + except Exception as e: + logger.error(f"Error executing MCP tool {self.name} (original: {self.original_name}): {str(e)}") + raise ToolExecutionError(f"Tool execution failed: {str(e)}") + +class ToolRegistry: + """Registry for all available tools""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ToolRegistry, cls).__new__(cls) + cls._instance._tools = {} + cls._instance._mcp_sessions = {} + return cls._instance + + def register_tool(self, tool: BaseTool) -> None: + """Register a new tool""" + if tool.name in self._tools: + logger.warning(f"Tool {tool.display_name} already registered, replacing") + self._tools[tool.name] = tool + logger.info(f"Registered tool: {tool.display_name}") + + def unregister_tool(self, name: str) -> None: + """Unregister a tool""" + if name in self._tools: + logger.info(f"Unregistered tool: {self._tools[name].display_name}") + del self._tools[name] + + def get_tool(self, name: str) -> Optional[BaseTool]: + """Get a tool by name""" + return self._tools.get(name) + + def list_tools(self) -> List[str]: + """List all registered tools with their namespaced names""" + return list(self._tools.keys()) + + def list_tools_with_details(self) -> List[Dict[str, Any]]: + """List all registered tools with detailed information""" + return [ + { + "namespaced_name": t.name, + "original_name": getattr(t, "original_name", t.name), + "type": t.__class__.__name__, + "server": t.config.get("server") if hasattr(t, "config") else None, + "module": t.config.get("module") if hasattr(t, "config") else None, + "description": t.description if hasattr(t, "description") else None, + "available": t.is_available, + } + for t in self._tools.values() + ] + + def get_available_tools(self) -> List[Tool]: + """Get all available tools in the format expected by the LLM""" + available_tools = [] + for name, tool in self._tools.items(): + if tool.is_available: + available_tools.append(Tool( + type="function", + function=tool.get_definition() + )) + return available_tools + + def register_mcp_session(self, name: str, session: ClientSession) -> None: + """Register an MCP session for tool providers""" + self._mcp_sessions[name] = session + + def get_mcp_session(self, name: str) -> Optional[ClientSession]: + """Get an MCP session by name""" + return self._mcp_sessions.get(name) + + async def init_from_config(self, config: Dict[str, Any]) -> None: + """Initialize tools from configuration""" + if not config or 'tools' not in config.get('inference', {}): + logger.warning("No tool configuration found") + return + + tool_config = config['inference']['tools'] + if not tool_config.get('enabled', False): + logger.info("Tools are disabled in configuration") + return + + logger.info(f"Initializing tools from config, enabled: {tool_config.get('enabled')}") + + registry = tool_config.get('registry', {}) + logger.info(f"Found {len(registry)} tools in registry configuration") + + # Track successfully registered tools + successful_tools = 0 + + for name, tool_def in registry.items(): + if not tool_def.get('enabled', True): + logger.info(f"Tool {name} is disabled in configuration") + continue + + logger.info(f"Registering tool {name} of type {tool_def.get('type')}") + + try: + if tool_def.get('type') == 'python': + # Python tool + module_path = tool_def.get('module') + if not module_path: + logger.error(f"Missing module path for Python tool {name}") + continue + + logger.info(f"Loading Python module {module_path} for tool {name}") + + module = importlib.import_module(module_path) + func = getattr(module, 'execute', None) + if not func: + logger.error(f"Module {module_path} does not have an execute function") + continue + + self.register_tool(PythonTool(name, tool_def, func)) + logger.info(f"Successfully registered Python tool: {name}") + successful_tools += 1 + + elif tool_def.get('type') == 'mcp': + # MCP tool - requires session + server = tool_def.get('server') + if not server: + logger.error(f"No server specified for MCP tool {name}") + continue + + # Detailed logging about MCP sessions + logger.debug(f"Available MCP sessions: {list(self._mcp_sessions.keys())}") + + # Check if we have a session for this server + if server not in self._mcp_sessions: + logger.error(f"MCP server '{server}' not found for tool {name}. Available servers: {list(self._mcp_sessions.keys())}") + continue + + logger.debug(f"Creating MCP tool {name} using server {server}") + + # Get session and verify it + session = self._mcp_sessions.get(server) + if not session: + logger.error(f"MCP session for server '{server}' is None") + continue + + logger.debug(f"MCP session for {server} obtained") + + # Verify the session is active before registering + try: + # Try to list tools to verify the session is working correctly + import asyncio + logger.debug(f"Testing MCP session responsiveness for {server}") + + max_retries = 2 + retry_count = 0 + tool_listing_success = False + server_tool_names = [] + + while retry_count <= max_retries and not tool_listing_success: + try: + # Use a reasonable timeout to prevent hanging + async with asyncio.timeout(10): + try: + # Check what type of session we have + from mcp_wrapper import MCPSession + if isinstance(session, MCPSession): + logger.debug(f"Using MCPSession list_tools method for {server} (attempt {retry_count+1})") + tools_response = await session.list_tools() + elif hasattr(session, 'tools') and hasattr(session.tools, 'list'): + logger.debug(f"Using standard tools.list method for {server} (attempt {retry_count+1})") + tools_response = await session.tools.list() + + # Check if we have tools in the response + if tools_response: + # Handle both object and list responses (our wrapper returns an object) + if hasattr(tools_response, "tools"): + tool_count = len(tools_response.tools) + tool_names = [t.name for t in tools_response.tools] + server_tool_names = tool_names # Save for validation later + logger.info(f"MCP server {server} has {tool_count} tools: {tool_names}") + + # Check for mismatches between server tools and registry + registry_tools = [name for name, tool in registry.items() + if tool.get('type') == 'mcp' and tool.get('server') == server] + + # Find tools in registry but not in server + missing_tools = set(registry_tools) - set(tool_names) + if missing_tools: + logger.error(f"MCP server {server} is missing tools defined in registry: {missing_tools}") + logger.error(f"Please update Redis config to remove these tools or fix the MCP server") + + # Find tools in server but not in registry + unknown_tools = set(tool_names) - set(registry_tools) + if unknown_tools: + logger.error(f"MCP server {server} has tools not defined in registry: {unknown_tools}") + logger.error(f"Please update Redis config to add these tools") + + tool_listing_success = True + elif isinstance(tools_response, list) and len(tools_response) > 0: + tool_count = len(tools_response) + tool_names = [t.get('name', 'unnamed') for t in tools_response] + server_tool_names = tool_names # Save for validation later + logger.info(f"MCP server {server} has {tool_count} tools: {tool_names}") + + # Process registry mismatch as above + registry_tools = [name for name, tool in registry.items() + if tool.get('type') == 'mcp' and tool.get('server') == server] + + missing_tools = set(registry_tools) - set(tool_names) + if missing_tools: + logger.error(f"MCP server {server} is missing tools defined in registry: {missing_tools}") + + unknown_tools = set(tool_names) - set(registry_tools) + if unknown_tools: + logger.error(f"MCP server {server} has tools not defined in registry: {unknown_tools}") + + tool_listing_success = True + else: + logger.warning(f"MCP server {server} returned unexpected tools format: {type(tools_response)}") + else: + logger.warning(f"MCP server {server} returned empty or invalid tools list") + + except Exception as e: + logger.warning(f"Error listing tools for {server} (attempt {retry_count+1}): {e}") + + # If we failed but have retries left, wait a bit before trying again + if not tool_listing_success: + retry_count += 1 + if retry_count <= max_retries: + logger.debug(f"Waiting before retry {retry_count} for {server}") + await asyncio.sleep(1.5) + + except asyncio.TimeoutError: + retry_count += 1 + logger.warning(f"Timeout listing tools for {server} (attempt {retry_count})") + if retry_count <= max_retries: + await asyncio.sleep(1.5) + + # Verify the tool exists in the MCP server before registering + if name not in server_tool_names: + logger.error(f"Tool {name} defined in registry but not found in MCP server {server}") + logger.error(f"Available tools from server: {server_tool_names}") + logger.error(f"Skipping registration of tool {name}") + continue + + # Create and register the tool + self.register_tool(MCPTool(name, tool_def, session)) + logger.info(f"Successfully registered MCP tool: {name}") + successful_tools += 1 + + except Exception as e: + logger.error(f"Error verifying MCP session for {server}: {e}", exc_info=True) + continue + + else: + logger.warning(f"Unknown tool type for {name}: {tool_def.get('type')}") + + except Exception as e: + logger.error(f"Error registering tool {name}: {str(e)}", exc_info=True) + + # Log summary + logger.info(f"Tool initialization complete. Registered {successful_tools}/{len(registry)} tools") + + # Log a summary of registered tools + all_tools = self.list_tools_with_details() + tool_summary = [f"{t['namespaced_name']} (original: {t['original_name']})" for t in all_tools] + logger.info(f"Registered tools: {', '.join(tool_summary)}") + + async def execute_tool(self, name: str, arguments: Dict[str, Any], + context: Dict[str, Any]) -> str: + """Execute a tool by name""" + tool = self.get_tool(name) + if not tool: + # Try lookup by original name for backward compatibility + original_name_matches = [t for t in self._tools.values() + if getattr(t, 'original_name', None) == name] + if original_name_matches: + tool = original_name_matches[0] + logger.info(f"Found tool by original name {name}, mapped to namespaced name {tool.name}") + else: + raise ToolExecutionError(f"Tool not found: {name}") + + if not tool.is_available: + raise ToolExecutionError(f"Tool is not available: {tool.display_name}") + + return await tool.execute(arguments, context) \ No newline at end of file diff --git a/gai-backend/tool_unit_converter.py b/gai-backend/tool_unit_converter.py new file mode 100644 index 000000000..29b2c7361 --- /dev/null +++ b/gai-backend/tool_unit_converter.py @@ -0,0 +1,176 @@ +""" +Unit conversion tool implementation. +""" +import logging +from typing import Dict, Any, Union, Optional + +logger = logging.getLogger(__name__) + +# Conversion factors for different unit types +CONVERSIONS = { + "length": { + "m": 1.0, # base unit (meters) + "km": 1000.0, + "cm": 0.01, + "mm": 0.001, + "in": 0.0254, + "ft": 0.3048, + "yd": 0.9144, + "mi": 1609.34 + }, + "mass": { + "kg": 1.0, # base unit (kilograms) + "g": 0.001, + "mg": 0.000001, + "lb": 0.453592, + "oz": 0.0283495 + }, + "volume": { + "l": 1.0, # base unit (liters) + "ml": 0.001, + "gal": 3.78541, + "qt": 0.946353, + "pt": 0.473176, + "cup": 0.236588, + "floz": 0.0295735 + }, + "temperature": { + # Special case, handled separately + "c": "celsius", + "f": "fahrenheit", + "k": "kelvin" + }, + "area": { + "m2": 1.0, # base unit (square meters) + "km2": 1000000.0, + "cm2": 0.0001, + "mm2": 0.000001, + "in2": 0.00064516, + "ft2": 0.092903, + "ac": 4046.86, + "ha": 10000.0 + }, + "time": { + "s": 1.0, # base unit (seconds) + "ms": 0.001, + "min": 60.0, + "h": 3600.0, + "day": 86400.0, + "week": 604800.0 + } +} + +def convert_temperature(value: float, from_unit: str, to_unit: str) -> float: + """Special handling for temperature conversions.""" + # Convert to Kelvin first (as the intermediate unit) + if from_unit == "c": + kelvin = value + 273.15 + elif from_unit == "f": + kelvin = (value - 32) * 5/9 + 273.15 + else: # already kelvin + kelvin = value + + # Convert from Kelvin to the target unit + if to_unit == "c": + return kelvin - 273.15 + elif to_unit == "f": + return (kelvin - 273.15) * 9/5 + 32 + else: # to kelvin + return kelvin + +def convert_unit(value: float, from_unit: str, to_unit: str, unit_type: str) -> Optional[float]: + """Convert a value from one unit to another within the same type.""" + from_unit = from_unit.lower() + to_unit = to_unit.lower() + + # Handle temperature separately + if unit_type == "temperature": + return convert_temperature(value, from_unit, to_unit) + + # For other unit types + if unit_type in CONVERSIONS and from_unit in CONVERSIONS[unit_type] and to_unit in CONVERSIONS[unit_type]: + # Convert to the base unit first, then to the target unit + base_value = value * CONVERSIONS[unit_type][from_unit] + return base_value / CONVERSIONS[unit_type][to_unit] + + return None + +def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: + """ + Execute the unit conversion tool. + + Args: + args: Dictionary with 'value', 'from_unit', 'to_unit', and 'unit_type' keys + context: Execution context (not used for unit conversion) + + Returns: + String representation of the converted value + """ + # Extract arguments + value = args.get('value') + from_unit = args.get('from_unit') + to_unit = args.get('to_unit') + unit_type = args.get('unit_type') + + # Validate arguments + if value is None or from_unit is None or to_unit is None or unit_type is None: + return "Error: Missing required parameters (value, from_unit, to_unit, unit_type)" + + try: + value = float(value) + except ValueError: + return "Error: Value must be a number" + + if unit_type not in CONVERSIONS: + return f"Error: Unsupported unit type. Supported types: {', '.join(CONVERSIONS.keys())}" + + # Normalize unit names and check if they are supported + from_unit = from_unit.lower() + to_unit = to_unit.lower() + + unit_map = CONVERSIONS[unit_type] + if from_unit not in unit_map: + return f"Error: Unsupported from_unit. Supported units for {unit_type}: {', '.join(unit_map.keys())}" + + if to_unit not in unit_map: + return f"Error: Unsupported to_unit. Supported units for {unit_type}: {', '.join(unit_map.keys())}" + + # Perform the conversion + result = convert_unit(value, from_unit, to_unit, unit_type) + if result is None: + return "Error: Conversion failed" + + # Format the result (round to 6 decimal places if needed) + if result == int(result): + return str(int(result)) + else: + return str(round(result, 6)) + +# Tool definition for configuration +TOOL_DEFINITION = { + "name": "unit_converter", + "description": "Convert values between different units of measurement. Supports length, mass, volume, temperature, area, and time conversions.", + "parameters": { + "type": "object", + "properties": { + "value": { + "type": "number", + "description": "The numeric value to convert" + }, + "from_unit": { + "type": "string", + "description": "The source unit (e.g., 'm', 'kg', 'c')" + }, + "to_unit": { + "type": "string", + "description": "The target unit (e.g., 'ft', 'lb', 'f')" + }, + "unit_type": { + "type": "string", + "description": "The type of conversion to perform", + "enum": ["length", "mass", "volume", "temperature", "area", "time"] + } + }, + "required": ["value", "from_unit", "to_unit", "unit_type"] + } +} \ No newline at end of file