Skip to content

Commit

Permalink
fix: cors + cors config
Browse files Browse the repository at this point in the history
  • Loading branch information
k11kirky committed Nov 25, 2024
1 parent 2653658 commit b504338
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ queue: # if using local task queue
max_workers: 10 # default
```
#### Using Environment Variables
Set the desired configuration options using environment variables.
Expand All @@ -226,6 +227,39 @@ export AGENTSERVE_REDIS_HOST=redis-server-host
export AGENTSERVE_REDIS_PORT=6379
```

### FastAPI Configuration

You can specify FastAPI settings, including CORS configuration, using the `fastapi` key in your `agentserve.yaml` configuration file.

**Example:**

```yaml
# agentserve.yaml
fastapi:
cors:
allow_origins:
- "http://localhost:3000"
- "https://yourdomain.com"
allow_credentials: true
allow_methods:
- "*"
allow_headers:
- "*"
```
#### Using Environment Variables

Alternatively, you can set the desired configuration options using environment variables.

**Example:**

```bash
export AGENTSERVE_CORS_ORIGINS="http://localhost:3000,https://yourdomain.com"
export AGENTSERVE_CORS_ALLOW_CREDENTIALS="true"
export AGENTSERVE_CORS_ALLOW_METHODS="GET,POST"
export AGENTSERVE_CORS_ALLOW_HEADERS="Content-Type,Authorization"
```

## Advanced Usage

### Integrating with Existing Projects
Expand Down
16 changes: 14 additions & 2 deletions agentserve/agent_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# agentserve/agent_server.py

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import ValidationError
from .queues.task_queue import TaskQueue
from .agent_registry import AgentRegistry
Expand All @@ -12,9 +13,20 @@
class AgentServer:
def __init__(self, config: Optional[Config] = None):
self.logger = setup_logger("agentserve.server")
self.app = FastAPI(debug=True)
self.agent_registry = AgentRegistry()
self.config = config or Config()
self.app = FastAPI()

# Add CORS middleware with custom origins
cors_config = self.config.get_nested('fastapi', 'cors', default={})
self.app.add_middleware(
CORSMiddleware,
allow_origins=cors_config.get('allow_origins', ["*"]),
allow_credentials=cors_config.get('allow_credentials', True),
allow_methods=cors_config.get('allow_methods', ["*"]),
allow_headers=cors_config.get('allow_headers', ["*"]),
)

self.agent_registry = AgentRegistry()
self.task_queue = self._initialize_task_queue()
self.agent = self.agent_registry.register_agent
self._setup_routes()
Expand Down
34 changes: 33 additions & 1 deletion agentserve/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@ def _load_config(self):

# Override with environment variables
config['task_queue'] = os.getenv('AGENTSERVE_TASK_QUEUE', config.get('task_queue', 'local'))

# FastAPI configuration
fastapi_config = config.setdefault('fastapi', {})
# CORS configuration within FastAPI config
cors_config = fastapi_config.setdefault('cors', {})
cors_origins_env = os.getenv('AGENTSERVE_CORS_ORIGINS')
if cors_origins_env:
cors_config['allow_origins'] = [origin.strip() for origin in cors_origins_env.split(',')]
else:
cors_config.setdefault('allow_origins', ["*"])

cors_config['allow_credentials'] = self._get_bool_env(
'AGENTSERVE_CORS_ALLOW_CREDENTIALS',
cors_config.get('allow_credentials', True)
)
cors_methods_env = os.getenv('AGENTSERVE_CORS_ALLOW_METHODS')
if cors_methods_env:
cors_config['allow_methods'] = [method.strip() for method in cors_methods_env.split(',')]
else:
cors_config.setdefault('allow_methods', ["*"])

cors_headers_env = os.getenv('AGENTSERVE_CORS_ALLOW_HEADERS')
if cors_headers_env:
cors_config['allow_headers'] = [header.strip() for header in cors_headers_env.split(',')]
else:
cors_config.setdefault('allow_headers', ["*"])

# Celery configuration
celery_broker_url = os.getenv('AGENTSERVE_CELERY_BROKER_URL')
Expand Down Expand Up @@ -59,4 +85,10 @@ def get_nested(self, *keys, default=None):
value = value.get(key)
if value is None:
return default
return value
return value

def _get_bool_env(self, env_var, default):
val = os.getenv(env_var)
if val is not None:
return val.lower() in ('true', '1', 'yes')
return default

0 comments on commit b504338

Please sign in to comment.