Skip to content

Commit

Permalink
Introduce tool injection
Browse files Browse the repository at this point in the history
  • Loading branch information
danopato committed Feb 28, 2025
1 parent d7da450 commit c72f6de
Show file tree
Hide file tree
Showing 11 changed files with 1,988 additions and 99 deletions.
31 changes: 30 additions & 1 deletion gai-backend/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import json
import time
import os
import logging

logger = logging.getLogger("config_manager")

class ConfigError(Exception):
"""Raised when config operations fail"""
Expand Down Expand Up @@ -118,28 +121,54 @@ async def write_config(self, config: Dict[str, Any], force: bool = False):

async def load_config(self, config_path: Optional[str] = None, force_reload: bool = False) -> Dict[str, Any]:
    """Load configuration, preferring (in order): an explicit file, the Redis
    cache when unchanged, fresh Redis data, or a minimal built-in default.

    Args:
        config_path: When given, load from this file and write it back to
            Redis (force=True), bypassing the cache entirely.
        force_reload: When True, skip the timestamp-based cache check and
            re-read from Redis even if nothing appears to have changed.

    Returns:
        The effective configuration dict. Also stored on
        ``self.current_config`` (except on the file path, which delegates
        persistence to ``write_config``).

    Raises:
        ConfigError: If any step of loading/parsing fails. The original
            exception is chained via ``from`` for debuggability.
    """
    try:
        logger.debug(f"Loading config - path: {config_path}, force_reload: {force_reload}")

        if config_path:
            # Explicit file wins: load it and persist to Redis unconditionally.
            logger.debug(f"Loading config from file: {config_path}")
            config = await self.load_from_file(config_path)
            await self.write_config(config, force=True)
            logger.debug("Loaded and wrote config from file")
            return config

        timestamp = await self.redis.get("config:last_update")
        logger.debug(f"Redis config timestamp: {timestamp}, last_load_time: {self.last_load_time}")

        # Cache hit: Redis hasn't been updated since our last load.
        if not force_reload and timestamp and self.last_load_time >= float(timestamp):
            logger.debug("Using cached config (no changes detected)")
            return self.current_config

        config_data = await self.redis.get("config:data")
        if not config_data:
            logger.warning("No configuration found in Redis")
            # Provide a minimal default config instead of raising an error
            default_config = {
                "inference": {
                    "endpoints": {},
                    "tools": {
                        "enabled": False,
                        "inject_defaults": False
                    }
                }
            }
            logger.debug("Using minimal default configuration")
            self.current_config = default_config
            return default_config

        logger.debug("Loading config from Redis")
        config = json.loads(config_data)
        config = self.process_config(config)

        if "tools" in config.get("inference", {}):
            tools_cfg = config["inference"]["tools"]
            logger.debug(f"Tools config loaded - enabled: {tools_cfg.get('enabled')}, inject_defaults: {tools_cfg.get('inject_defaults')}")

        self.current_config = config
        # Record the Redis timestamp so the cache check above can short-circuit
        # next time; fall back to wall-clock when Redis had no timestamp.
        self.last_load_time = float(timestamp) if timestamp else time.time()
        logger.debug("Successfully loaded config from Redis")
        return config

    except Exception as e:
        logger.error(f"Failed to load config: {e}", exc_info=True)
        # Chain the cause (PEP 3134) so the original traceback survives.
        raise ConfigError(f"Failed to load config: {e}") from e

async def check_for_updates(self) -> bool:
Expand Down
27 changes: 21 additions & 6 deletions gai-backend/inference_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
logger = logging.getLogger(__name__)

class ModelAdapter:
logger = logging.getLogger(__name__)

@staticmethod
def parse_response(api_type: str, response: Dict[str, Any], model: str, request_id: str) -> ChatCompletion:
if api_type in ('openai', 'openrouter'):
Expand Down Expand Up @@ -72,8 +74,8 @@ def parse_anthropic_response(response: Dict[str, Any], model: str, response_id:
system_fingerprint=response.get('system_fingerprint')
)

@staticmethod
def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
@classmethod
def prepare_anthropic_request(cls, model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
headers = {
"Content-Type": "application/json",
"x-api-key": api_key,
Expand Down Expand Up @@ -129,6 +131,7 @@ def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, reques
# Handle tools
tools = request.get_effective_tools()
if tools:
cls.logger.debug(f"Including {len(tools)} tools in Anthropic request")
data["tools"] = []
for tool in tools:
tool_def = {
Expand All @@ -137,8 +140,9 @@ def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, reques
"input_schema": tool.function.parameters
}
data["tools"].append(tool_def)

# Handle tool choice
cls.logger.debug(f"Tools in Anthropic request: {json.dumps(data['tools'])}")

# Handle tool choice - only when tools are present
tool_choice = request.get_effective_tool_choice()
if isinstance(tool_choice, str):
data["tool_choice"] = "none" if tool_choice == "none" else {"type": "auto"}
Expand All @@ -149,6 +153,8 @@ def prepare_anthropic_request(model_config: Dict[str, Any], api_key: str, reques
}
else:
data["tool_choice"] = {"type": "auto"}
else:
cls.logger.debug("No tools to include in Anthropic request")

# Copy other parameters
if request.temperature is not None:
Expand Down Expand Up @@ -188,8 +194,8 @@ def parse_openai_response(response: Dict[str, Any], model: str, response_id: str
)
)

@staticmethod
def prepare_openai_request(model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
@classmethod
def prepare_openai_request(cls, model_config: Dict[str, Any], api_key: str, request: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str]]:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
Expand All @@ -198,6 +204,15 @@ def prepare_openai_request(model_config: Dict[str, Any], api_key: str, request:
data = request.dict(exclude_none=True, exclude={'request_id'})
data['model'] = model_config['id']

# Log tools information
tools = request.get_effective_tools()
if tools:
cls.logger.debug(f"Including {len(tools)} tools in OpenAI request")
for tool in tools:
cls.logger.debug(f"Tool in OpenAI request: {tool.function.name}")
else:
cls.logger.debug("No tools to include in OpenAI request")

return data, headers

@classmethod
Expand Down
10 changes: 8 additions & 2 deletions gai-backend/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,23 @@ async def chat_completion(
@app.get("/v1/models")
async def list_openai_models():
    """OpenAI-compatible model listing endpoint.

    Delegates to ``api.list_openai_models()``. Domain errors
    (``InferenceError``) propagate to the framework's error handling;
    any other failure is logged and degraded to an empty model list so
    clients polling this endpoint don't see a 500.
    """
    try:
        logger.debug("Handling request to /v1/models")
        return await api.list_openai_models()
    except Exception as e:
        logger.error(f"Error listing OpenAI models: {e}", exc_info=True)
        if isinstance(e, InferenceError):
            raise
        # Return empty list instead of error
        return {"data": []}

@app.get("/v1/inference/models")
async def list_inference_models():
    """Backend-native model listing endpoint.

    Delegates to ``api.list_models()``. Domain errors
    (``InferenceError``) propagate; any other failure is logged and
    degraded to an empty dict, mirroring the behavior of ``/v1/models``.
    """
    try:
        logger.debug("Handling request to /v1/inference/models")
        return await api.list_models()
    except Exception as e:
        logger.error(f"Error listing inference models: {e}", exc_info=True)
        if isinstance(e, InferenceError):
            raise
        # Return empty dict instead of error
        return {}
Loading

0 comments on commit c72f6de

Please sign in to comment.