Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup model routing config and plan routing to o1 #6189

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ codeact_enable_jupyter = true
# List of microagents to disable
#disabled_microagents = []

# Whether to enable plan routing to reasoning models
#enable_plan_routing = false

[agent.RepoExplorerAgent]
# Example: use a cheaper model for RepoExplorerAgent to reduce cost, especially
# useful when an agent doesn't demand high quality but uses a lot of tokens
Expand Down Expand Up @@ -276,6 +279,14 @@ llm_config = 'gpt3'
# The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init)
#security_analyzer = ""

################################ Model Routing ###############################
# Configuration for model routing features
##############################################################################
[model_routing]

# The reasoning model to use for plan generation
reasoning_model = "o1-preview-2024-09-12"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
reasoning_model = "o1-preview-2024-09-12"
[llm.reasoning_model]
model = "o1-preview-2024-09-12"
...

Copy link
Contributor Author

@ryanhoangt ryanhoangt Jan 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is also another approach, my thought is for now we only use reasoning models specifically for model routing, so I put it in this config group (with other values in the future). When we also use them for other purposes, we can probably move to llm-specific groups?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that we can reuse the way we define a model (which will implicitly take care of the correct loading and init all base_url etc).

It doesn't say which component of openhands loads the definition of [llm.reasoning_model], it can be the routing component.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify, if a user wants to use a reasoning model today, for the agent, they can do so. They just choose a reasoning model and configure it. Ability to use it isn't new?

We can just avoid to duplicate LLMConfig settings ("reasoning_model", "reasoning_model_base_url", "reasoning_model_api_key", "reasoning_model_aws...something" etc) into the new routing section, instead we can reference existing configurations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that sounds good to me, thanks for the suggestion! I'll try to address this after getting the routing behavior to work


#################################### Eval ####################################
# Configuration for the evaluation, please refer to the specific evaluation
# plugin for the available options
Expand Down
1 change: 1 addition & 0 deletions evaluation/benchmarks/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def get_config(
codeact_enable_browsing=RUN_WITH_BROWSING,
codeact_enable_llm_editor=False,
condenser=metadata.condenser_config,
# enable_plan_routing=True,
)
config.set_agent_config(agent_config)
return config
Expand Down
21 changes: 18 additions & 3 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
from openhands.router.plan import LLMBasedPlanRouter
from openhands.runtime.plugins import (
AgentSkillsRequirement,
JupyterRequirement,
PluginRequirement,
)
from openhands.utils.prompt import PromptManager
from openhands.utils.trajectory import format_trajectory


class CodeActAgent(Agent):
Expand Down Expand Up @@ -120,6 +122,10 @@ def __init__(
self.condenser = Condenser.from_config(self.config.condenser)
logger.debug(f'Using condenser: {self.condenser}')

self.plan_router = (
LLMBasedPlanRouter(self.llm.config) if config.enable_plan_routing else None
)

def get_action_message(
self,
action: Action,
Expand Down Expand Up @@ -378,11 +384,20 @@ def step(self, state: State) -> Action:
if latest_user_message and latest_user_message.content.strip() == '/exit':
return AgentFinishAction()

params: dict = {}

# prepare what we want to send to the LLM
messages = self._get_messages(state)
params: dict = {
'messages': self.llm.format_messages_for_llm(messages),
}
params['messages'] = self.llm.format_messages_for_llm(messages)

# check if model routing is needed
if self.plan_router:
formatted_trajectory = format_trajectory(messages)

if self.plan_router.should_route_to_custom_model(formatted_trajectory):
logger.info('🧭 Routing to custom model...')
params['use_reasoning_model'] = True

params['tools'] = self.tools
if self.mock_function_calling:
params['mock_function_calling'] = True
Expand Down
2 changes: 2 additions & 0 deletions openhands/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_field_info,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.core.config.utils import (
Expand All @@ -27,6 +28,7 @@
'LLMConfig',
'SandboxConfig',
'SecurityConfig',
'ModelRoutingConfig',
'load_app_config',
'load_from_env',
'load_from_toml',
Expand Down
2 changes: 2 additions & 0 deletions openhands/core/config/agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AgentConfig:
use_microagents: Whether to use microagents at all. Default is True.
disabled_microagents: A list of microagents to disable. Default is None.
condenser: Configuration for the memory condenser. Default is NoOpCondenserConfig.
enable_plan_routing: Whether to enable plan routing to reasoning models. Default is False.
"""

codeact_enable_browsing: bool = True
Expand All @@ -32,6 +33,7 @@ class AgentConfig:
use_microagents: bool = True
disabled_microagents: list[str] | None = None
condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig) # type: ignore
enable_plan_routing: bool = False

def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
Expand Down
2 changes: 2 additions & 0 deletions openhands/core/config/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_field_info,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig

Expand Down Expand Up @@ -51,6 +52,7 @@ class AppConfig:
default_agent: str = OH_DEFAULT_AGENT
sandbox: SandboxConfig = field(default_factory=SandboxConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
model_routing: ModelRoutingConfig = field(default_factory=ModelRoutingConfig)
runtime: str = 'docker'
file_store: str = 'local'
file_store_path: str = '/tmp/openhands_file_store'
Expand Down
32 changes: 32 additions & 0 deletions openhands/core/config/model_routing_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import dataclass, fields

from openhands.core.config.config_utils import get_field_info


@dataclass
class ModelRoutingConfig:
reasoning_model: str = 'o1-preview-2024-09-12'

def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
dict = {}
for f in fields(self):
dict[f.name] = get_field_info(f)
return dict

def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)

attr_str.append(f'{attr_name}={repr(attr_value)}')

return f"ModelRoutingConfig({', '.join(attr_str)})"

@classmethod
def from_dict(cls, model_routing_config_dict: dict) -> 'ModelRoutingConfig':
return cls(**model_routing_config_dict)

def __repr__(self):
return self.__str__()
12 changes: 8 additions & 4 deletions openhands/core/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
from openhands.core import logger
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.app_config import AppConfig
from openhands.core.config.config_utils import (
OH_DEFAULT_AGENT,
OH_MAX_ITERATIONS,
)
from openhands.core.config.config_utils import OH_DEFAULT_AGENT, OH_MAX_ITERATIONS
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.model_routing_config import ModelRoutingConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
from openhands.storage import get_file_store
Expand Down Expand Up @@ -141,6 +139,12 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
)
agent_config = AgentConfig(**nested_value)
cfg.set_agent_config(agent_config, nested_key)
elif key is not None and key.lower() == 'model_routing':
logger.openhands_logger.debug(
'Attempt to load model routing config from config toml'
)
model_routing_config = ModelRoutingConfig.from_dict(value)
cfg.model_routing = model_routing_config
elif key is not None and key.lower() == 'llm':
logger.openhands_logger.debug(
'Attempt to load default LLM config from config toml'
Expand Down
7 changes: 3 additions & 4 deletions openhands/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import (
AppConfig,
)
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.llm.llm import LLM
Expand Down Expand Up @@ -61,8 +59,9 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
model_routing_config = config.model_routing
agent = agent_cls(
llm=LLM(config=llm_config),
llm=LLM(config=llm_config, model_routing_config=model_routing_config),
config=agent_config,
)
if agent.prompt_manager:
Expand Down
18 changes: 16 additions & 2 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import requests

from openhands.core.config import LLMConfig
from openhands.core.config import LLMConfig, ModelRoutingConfig

with warnings.catch_warnings():
warnings.simplefilter('ignore')
Expand Down Expand Up @@ -71,6 +71,7 @@
'claude-3-5-haiku-20241022',
'gpt-4o-mini',
'gpt-4o',
'o1',
]


Expand All @@ -85,6 +86,7 @@ def __init__(
self,
config: LLMConfig,
metrics: Metrics | None = None,
model_routing_config: ModelRoutingConfig | None = None,
):
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

Expand All @@ -93,13 +95,15 @@ def __init__(
Args:
config: The LLM configuration.
metrics: The metrics to use.
model_routing_config: The model routing configuration.
"""
self._tried_model_info = False
self.metrics: Metrics = (
metrics if metrics is not None else Metrics(model_name=config.model)
)
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
self.model_routing_config = model_routing_config

self.model_info: ModelInfo | None = None

Expand Down Expand Up @@ -158,6 +162,7 @@ def wrapper(*args, **kwargs):

messages: list[dict[str, Any]] | dict[str, Any] = []
mock_function_calling = kwargs.pop('mock_function_calling', False)
use_reasoning_model = kwargs.pop('use_reasoning_model', False)

# some callers might send the model and messages directly
# litellm allows positional args, like completion(model, messages, **kwargs)
Expand Down Expand Up @@ -189,6 +194,15 @@ def wrapper(*args, **kwargs):
kwargs['stop'] = STOP_WORDS
mock_fncall_tools = kwargs.pop('tools')

if use_reasoning_model:
if self.model_routing_config is None:
raise ValueError(
'Model routing config is required for model routing.'
)

# Replace the model with the reasoning model
kwargs['model'] = self.model_routing_config.reasoning_model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is model enough, or also: custom provider, base URL?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could design the reasoning model not as a part of an LLM instance, but as a second LLM instance in the agent?

Copy link
Contributor Author

@ryanhoangt ryanhoangt Jan 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is model enough, or also: custom provider, base URL?

Yeah, I think we also need to allow user to set these, especially if they don't use via a llm proxy 🤔

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using [llm.reasoning_model] will do it implicitly!


# if we have no messages, something went very wrong
if not messages:
raise ValueError(
Expand Down Expand Up @@ -636,7 +650,7 @@ def __str__(self):
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
return f'LLM(model={self.config.model},reasoning_model={self.model_routing_config.reasoning_model if self.model_routing_config else None})'

def __repr__(self):
return str(self)
Expand Down
7 changes: 7 additions & 0 deletions openhands/router/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod


class BaseRouter(ABC):
@abstractmethod
def should_route_to_custom_model(self, prompt: str) -> bool:
pass
4 changes: 4 additions & 0 deletions openhands/router/plan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from openhands.router.plan.llm_based import LLMBasedPlanRouter
from openhands.router.plan.rule_based import RuleBasedPlanRouter

__all__ = ['RuleBasedPlanRouter', 'LLMBasedPlanRouter']
43 changes: 43 additions & 0 deletions openhands/router/plan/llm_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import copy

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM
from openhands.router.base import BaseRouter
from openhands.router.plan.prompts import (
TRAJECTORY_JUDGE_REASONING_SYSTEM_PROMPT,
TRAJECTORY_JUDGE_REASONING_USER_PROMPT,
)


class LLMBasedPlanRouter(BaseRouter):
"""
Router that routes the prompt that is judged by a LLM as complex and requires a step-by-step plan.
"""

JUDGE_MODEL = 'gpt-4o'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be interesting to see if we can experiment with cheaper model for that 🤔


def __init__(self, llm_config: LLMConfig):
super().__init__()

judge_llm_config = copy.deepcopy(llm_config)
self.judge_llm = LLM(judge_llm_config)

def should_route_to_custom_model(self, prompt: str) -> bool:
messages = [
{
'role': 'system',
'content': TRAJECTORY_JUDGE_REASONING_SYSTEM_PROMPT,
},
{
'role': 'user',
'content': TRAJECTORY_JUDGE_REASONING_USER_PROMPT.format(
interaction_log=prompt
),
},
]

response = self.judge_llm.completion(
messages=messages,
model=self.JUDGE_MODEL,
)
return int(response['choices'][0]['message']['content'].strip()) == 1
Loading
Loading