diff --git a/gai-backend/config_manager.py b/gai-backend/config_manager.py index d7aa9d608..e549e74ae 100644 --- a/gai-backend/config_manager.py +++ b/gai-backend/config_manager.py @@ -7,6 +7,9 @@ logger = logging.getLogger("config_manager") +# Environment variable for API keys JSON +ORCHID_GENAI_API_KEYS_ENV = "ORCHID_GENAI_API_KEYS" + class ConfigError(Exception): """Raised when config operations fail""" pass @@ -17,6 +20,33 @@ def __init__(self, redis: Redis): self.last_load_time = 0 self.current_config = {} + def _load_api_keys_from_env(self) -> Optional[Dict[str, str]]: + """Load API keys from environment variable""" + api_keys_json = os.environ.get(ORCHID_GENAI_API_KEYS_ENV) + if not api_keys_json: + return None + + try: + api_keys = json.loads(api_keys_json) + if not isinstance(api_keys, dict): + logger.warning(f"Invalid API keys format in {ORCHID_GENAI_API_KEYS_ENV}: expected JSON object") + return None + return api_keys + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON from {ORCHID_GENAI_API_KEYS_ENV}") + return None + + def _apply_api_keys_to_config(self, config: Dict[str, Any], api_keys: Dict[str, str]) -> None: + """Apply API keys from environment to config endpoints""" + if 'inference' not in config or 'endpoints' not in config['inference']: + return + + endpoints = config['inference']['endpoints'] + for endpoint_id, endpoint in endpoints.items(): + if endpoint_id in api_keys: + endpoint['api_key'] = api_keys[endpoint_id] + logger.debug(f"Applied API key from environment to endpoint: {endpoint_id}") + async def load_from_file(self, config_path: str) -> Dict[str, Any]: try: with open(config_path, 'r') as f: @@ -40,6 +70,12 @@ def process_config(self, config: Dict[str, Any]) -> Dict[str, Any]: if not endpoints: raise ConfigError("No inference endpoints configured") + # Apply API keys from environment variable if it exists + env_api_keys = self._load_api_keys_from_env() + if env_api_keys: + logger.info("Found API keys in environment variables, overriding config values") + self._apply_api_keys_to_config(config, env_api_keys) + total_models = 0 for endpoint_id, endpoint in endpoints.items():