Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor I/O utils; allow 'task' command line parameter in cli.py #6187

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openhands/agenthub/micro/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.utils import json
from openhands.events.action import Action
from openhands.events.event import Event
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.event import event_to_memory
from openhands.io import json
from openhands.llm.llm import LLM


Expand Down
36 changes: 18 additions & 18 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FileEditObservation,
NullObservation,
)
from openhands.io import read_input, read_task


def display_message(message: str):
Expand Down Expand Up @@ -83,29 +84,21 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)


def read_input(config: AppConfig) -> str:
    """Prompt the user for a message, honoring the multiline setting in ``config``.

    When ``config.cli_multiline_input`` is truthy, lines are collected until the
    user enters ``/exit`` on its own line, and are returned joined by newlines.
    Otherwise a single line is read. Trailing whitespace is stripped from every
    line that is read.
    """
    if not config.cli_multiline_input:
        return input('>> ').rstrip()

    print('Enter your message (enter "/exit" on a new line to finish):')
    collected = []
    # Read one line per iteration; the terminator itself is never stored.
    while (entry := input('>> ').rstrip()) != '/exit':
        collected.append(entry)
    return '\n'.join(collected)


async def main(loop: asyncio.AbstractEventLoop):
"""Runs the agent in CLI mode"""

args = parse_arguments()

logger.setLevel(logging.WARNING)

config = setup_config_from_args(args)
# Load config from toml and override with command line arguments
config: AppConfig = setup_config_from_args(args)

# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)

# If we have a task, create initial user action
initial_user_action = MessageAction(content=task_str) if task_str else None

sid = str(uuid4())

Expand All @@ -118,7 +111,9 @@ async def main(loop: asyncio.AbstractEventLoop):

async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
next_message = await loop.run_in_executor(None, read_input, config)
next_message = await loop.run_in_executor(
None, read_input, config.cli_multiline_input
)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
Expand Down Expand Up @@ -164,7 +159,12 @@ def on_event(event: Event) -> None:

await runtime.connect()

asyncio.create_task(prompt_for_next_task())
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
event_stream.add_event(initial_user_action, EventSource.USER)
else:
# Otherwise prompt for the user's first message right away
asyncio.create_task(prompt_for_next_task())

await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
Expand Down
52 changes: 10 additions & 42 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol

Expand Down Expand Up @@ -29,6 +28,7 @@
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.io import read_input, read_task
from openhands.runtime.base import Runtime


Expand All @@ -41,32 +41,6 @@ def __call__(
) -> str: ...


def read_task_from_file(file_path: str) -> str:
    """Return the full text of the task file at ``file_path``, decoded as UTF-8."""
    with open(file_path, encoding='utf-8') as task_file:
        contents = task_file.read()
    return contents


def read_task_from_stdin() -> str:
    """Consume standard input until EOF and return everything that was read."""
    piped_text = sys.stdin.read()
    return piped_text


def read_input(config: AppConfig) -> str:
    """Read a user message from the prompt according to ``config``.

    With ``config.cli_multiline_input`` enabled, the user is prompted repeatedly
    and every line entered before ``/exit`` is joined with newlines. With it
    disabled, exactly one line is read. Each line has trailing whitespace
    removed.
    """

    def _ask() -> str:
        # One prompt, one line, trailing whitespace dropped.
        return input('>> ').rstrip()

    if not config.cli_multiline_input:
        return _ask()

    print('Enter your message (enter "/exit" on a new line to finish):')
    # iter(callable, sentinel) keeps prompting until the callable yields '/exit'.
    return '\n'.join(iter(_ask, '/exit'))


async def run_controller(
config: AppConfig,
initial_user_action: Action,
Expand Down Expand Up @@ -118,7 +92,6 @@ async def run_controller(
assert isinstance(
initial_user_action, Action
), f'initial user actions must be an Action, got {type(initial_user_action)}'
# Logging
logger.debug(
f'Agent Controller Initialized: Running agent {agent.name}, model '
f'{agent.llm.config.model}, with actions: {initial_user_action}'
Expand Down Expand Up @@ -146,7 +119,7 @@ def on_event(event: Event):
if exit_on_message:
message = '/exit'
elif fake_user_response_fn is None:
message = read_input(config)
message = read_input(config.cli_multiline_input)
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
Expand Down Expand Up @@ -243,28 +216,23 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
if __name__ == '__main__':
args = parse_arguments()

config = setup_config_from_args(args)
config: AppConfig = setup_config_from_args(args)

# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()
# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)

initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:

if not task_str:
raise ValueError('No task provided. Please specify a task through -t, -f.')

# Create initial user action
initial_user_action: MessageAction = MessageAction(content=task_str)

# Set session name
session_name = args.name
sid = generate_sid(config, session_name)
Expand Down
Empty file removed openhands/core/utils/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import Callable, Iterable

from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.io import json
from openhands.storage import FileStore
from openhands.storage.locations import (
get_conversation_dir,
Expand Down
11 changes: 11 additions & 0 deletions openhands/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from openhands.io.io import read_input, read_task, read_task_from_file
from openhands.io.json import dumps, loads, my_default_encoder

__all__ = [
'read_input',
'read_task_from_file',
'read_task',
'my_default_encoder',
'dumps',
'loads',
]
40 changes: 40 additions & 0 deletions openhands/io/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import sys


def read_input(cli_multiline_input: bool = False) -> str:
    """Prompt the user with ``>> `` and return the entered message.

    Args:
        cli_multiline_input: When true, keep prompting and collect lines until
            the user enters ``/exit`` on its own line; the collected lines are
            joined with newlines. When false, read a single line.

    Returns:
        The message, with trailing whitespace stripped from every line read.
    """
    if not cli_multiline_input:
        return input('>> ').rstrip()

    print('Enter your message (enter "/exit" on a new line to finish):')
    collected: list[str] = []
    # The sentinel form of iter() stops as soon as a stripped line equals '/exit',
    # so the terminator itself never ends up in the message.
    for entry in iter(lambda: input('>> ').rstrip(), '/exit'):
        collected.append(entry)
    return '\n'.join(collected)


def read_task_from_file(file_path: str) -> str:
    """Load a task description from disk.

    Args:
        file_path: Path of the file holding the task text.

    Returns:
        The entire file content, decoded as UTF-8.
    """
    with open(file_path, encoding='utf-8') as task_file:
        task_text = task_file.read()
    return task_text


def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str:
    """Resolve the initial task from the first available source.

    Sources are tried in this order: the file named by ``args.file``, the
    literal ``args.task`` string, then standard input when it is not a TTY.

    Args:
        args: Parsed command-line arguments; ``file`` and ``task`` are read.
        cli_multiline_input: Passed to ``read_input`` when the task is taken
            from standard input.

    Returns:
        The task text, or an empty string when no source supplied one.
    """
    if args.file:
        return read_task_from_file(args.file)
    if args.task:
        return args.task
    # Only read stdin when it is not a terminal (e.g. piped or redirected input).
    if not sys.stdin.isatty():
        return read_input(cli_multiline_input)
    return ''
File renamed without changes.
4 changes: 2 additions & 2 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.core.utils import json
from openhands.io import json

messages: list[dict[str, Any]] | dict[str, Any] = []
mock_function_calling = kwargs.pop('mock_function_calling', False)
Expand Down Expand Up @@ -374,7 +374,7 @@ def init_model_info(self):
# noinspection PyBroadException
except Exception:
pass
from openhands.core.utils import json
from openhands.io import json

logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}')

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

from openhands.core.cli import read_input
from openhands.core.config import AppConfig
from openhands.io import read_input


def test_single_line_input():
Expand All @@ -10,7 +10,7 @@ def test_single_line_input():
config.cli_multiline_input = False

with patch('builtins.input', return_value='hello world'):
result = read_input(config)
result = read_input(config.cli_multiline_input)
assert result == 'hello world'


Expand All @@ -23,5 +23,5 @@ def test_multiline_input():
mock_inputs = ['line 1', 'line 2', 'line 3', '/exit']

with patch('builtins.input', side_effect=mock_inputs):
result = read_input(config)
result = read_input(config.cli_multiline_input)
assert result == 'line 1\nline 2\nline 3'
2 changes: 1 addition & 1 deletion tests/unit/test_json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime

from openhands.core.utils import json
from openhands.events.action import MessageAction
from openhands.io import json


def test_event_serialization_deserialization():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_response_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from openhands.agenthub.micro.agent import parse_response as parse_response_micro
from openhands.core.exceptions import LLMResponseError
from openhands.core.utils.json import loads as custom_loads
from openhands.events.action import (
FileWriteAction,
MessageAction,
)
from openhands.io import loads as custom_loads


@pytest.mark.parametrize(
Expand Down