From 8c34134f4bcc4630276b20dc238c808371dde7de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albert=20=C3=96rwall?= Date: Sat, 1 Feb 2025 20:29:47 +0100 Subject: [PATCH] Make async --- moatless-ui/src/pages/runs/[id].tsx | 38 +-- .../src/pages/runs/components/RunEvents.tsx | 57 +++-- moatless-ui/src/pages/validate/index.tsx | 28 ++- moatless/actions/action.py | 4 +- moatless/actions/append_string.py | 2 +- moatless/actions/claude_text_editor.py | 2 +- moatless/actions/create_file.py | 2 +- moatless/actions/finish.py | 2 +- moatless/actions/identify_mixin.py | 4 +- moatless/actions/insert_line.py | 2 +- moatless/actions/list_files.py | 2 +- moatless/actions/reject.py | 2 +- moatless/actions/respond.py | 2 +- moatless/actions/run_tests.py | 2 +- moatless/actions/search_base.py | 8 +- moatless/actions/string_replace.py | 2 +- moatless/actions/verified_finish.py | 2 +- moatless/actions/view_code.py | 4 +- moatless/actions/view_diff.py | 2 +- moatless/agent/agent.py | 29 +-- moatless/agent/events.py | 8 +- moatless/agentic_system.py | 23 +- moatless/api/api.py | 79 +++--- moatless/api/runs/api.py | 20 +- moatless/api/runs/schema.py | 19 +- moatless/api/swebench/api.py | 27 +- moatless/cli.py | 7 +- moatless/completion/anthropic.py | 231 ------------------ moatless/completion/base.py | 21 +- moatless/events.py | 33 +-- moatless/loop.py | 4 +- moatless/runner.py | 22 +- moatless/search_tree.py | 2 +- moatless/validation/code_flow_validation.py | 4 +- 34 files changed, 212 insertions(+), 484 deletions(-) delete mode 100644 moatless/completion/anthropic.py diff --git a/moatless-ui/src/pages/runs/[id].tsx b/moatless-ui/src/pages/runs/[id].tsx index e40ea963..c183ac9c 100644 --- a/moatless-ui/src/pages/runs/[id].tsx +++ b/moatless-ui/src/pages/runs/[id].tsx @@ -7,29 +7,27 @@ import { Alert, AlertDescription } from '@/lib/components/ui/alert'; import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '@/lib/components/ui/resizable'; import { RunStatus } from './components/RunStatus'; import { RunEvents } from './components/RunEvents'; -import { useWebSocketStore } from '@/lib/stores/websocketStore'; -import { useMemo } from 'react'; +import { useEffect } from 'react'; import { ScrollArea } from '@/lib/components/ui/scroll-area'; import { TimelineItemDetails } from './components/TimelineItemDetails'; +import { useQueryClient } from '@tanstack/react-query'; +import { useWebSocketStore } from '@/lib/stores/websocketStore'; export function RunPage() { const { id } = useParams<{ id: string }>(); const { data: runData, isError, error } = useRun(id!); + const queryClient = useQueryClient(); + const { subscribe } = useWebSocketStore(); - // Use useMemo to cache the selector functions - //const selectMessages = useMemo( - // () => (state: any) => state.messages, - // [id] - //); - - const selectConnectionStatus = useMemo( - () => (state: any) => state.connectionStatus, - [] - ); + useEffect(() => { + if (!id) return; + + const unsubscribe = subscribe(`run.${id}`, () => { + queryClient.invalidateQueries({ queryKey: ['run', id] }); + }); - // Use the memoized selectors - //const messages = useWebSocketStore(selectMessages); - const wsStatus = useWebSocketStore(selectConnectionStatus); + return () => unsubscribe(); + }, [id, subscribe, queryClient]); if (isError) { return ( @@ -59,16 +57,6 @@ export function RunPage() { ); } - // Combine WebSocket messages with initial events - //const allEvents = useMemo(() => { - // const wsEvents = messages.map(msg => ({ - // event_type: msg.type, - // timestamp: new Date().toISOString(), - // data: { message: msg.message || msg.error } - // })); - // return [...(runData.events || []), ...wsEvents]; - //}, [runData.events, messages]); - return (
b.timestamp - a.timestamp); +export function RunEvents({ events, className }: RunEventsProps) { + // just sort in reverse order + const reversedEvents = [...events].reverse(); const getEventIcon = (type: string) => { switch (type) { case 'error': return ; case 'agent_message': - return ; + return ; case 'system_message': - return ; + return ; case 'user_message': - return ; + return ; + case 'agent_action_created': + return ; + case 'agent_action_executed': + return ; default: - return ; + return ; } }; + const formatEventType = (type: string) => { + return type + .split('_') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(' '); + }; + + if (events.length === 0) { + return ( +
+ No events yet +
+ ); + } + return (
- {sortedEvents.map((event, i) => ( + {reversedEvents.map((event, i) => (
{getEventIcon(event.event_type)} - {event.event_type.replace(/_/g, ' ')} + {formatEventType(event.event_type)} + {event.node_id !== undefined && ( + + Node {event.node_id} + + )}
{formatDistanceToNow(new Date(event.timestamp))} ago
- {event.data.message && ( -

- {event.data.message} + {event.action_name && ( +

+ Action: {event.action_name}

)}
diff --git a/moatless-ui/src/pages/validate/index.tsx b/moatless-ui/src/pages/validate/index.tsx index ead4f150..ed91c62e 100644 --- a/moatless-ui/src/pages/validate/index.tsx +++ b/moatless-ui/src/pages/validate/index.tsx @@ -1,4 +1,4 @@ -import { useState } from 'react'; +import { useState, useEffect } from 'react'; import { Button } from '@/lib/components/ui/button'; import { Loader2 } from 'lucide-react'; import { toast } from 'sonner'; @@ -10,6 +10,8 @@ import { ModelSelector } from '@/lib/components/selectors/ModelSelector'; import { useValidationStore } from '@/stores/validationStore'; import { AlertCircle } from 'lucide-react'; import { Alert, AlertDescription, AlertTitle } from '@/lib/components/ui/alert'; +import { useWebSocket } from '@/lib/stores/websocketStore'; +import { useQueryClient } from '@tanstack/react-query'; export function ValidatePage() { const navigate = useNavigate(); @@ -31,9 +33,12 @@ export function ValidatePage() { // Add error state const [error, setError] = useState(null); + const queryClient = useQueryClient(); + const { subscribe, unsubscribe } = useWebSocket(); + const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); - setError(null); // Clear any previous errors + setError(null); try { const response = await startValidation.mutateAsync({ @@ -56,14 +61,6 @@ export function ValidatePage() { } }; - if (startValidation.isLoading) { - return ( -
- -
- ); - } - return (

Validate Agent

@@ -97,9 +94,16 @@ export function ValidatePage() {
diff --git a/moatless/actions/action.py b/moatless/actions/action.py index c0780f74..48630642 100644 --- a/moatless/actions/action.py +++ b/moatless/actions/action.py @@ -56,7 +56,7 @@ class Action(BaseModel, ABC): _workspace: Workspace = PrivateAttr(default=None) - def execute(self, args: ActionArguments, file_context: FileContext | None = None) -> Observation: + async def execute(self, args: ActionArguments, file_context: FileContext | None = None) -> Observation: """ Execute the action. """ @@ -67,7 +67,7 @@ def execute(self, args: ActionArguments, file_context: FileContext | None = None message = self._execute(args, file_context=file_context) return Observation.create(message) - def _execute(self, args: ActionArguments, file_context: FileContext | None = None) -> str | None: + async def _execute(self, args: ActionArguments, file_context: FileContext | None = None) -> str | None: """ Execute the action and return the updated FileContext. """ diff --git a/moatless/actions/append_string.py b/moatless/actions/append_string.py index 84e40aa5..e5b0b42a 100644 --- a/moatless/actions/append_string.py +++ b/moatless/actions/append_string.py @@ -62,7 +62,7 @@ class AppendString(Action, CodeActionValueMixin, CodeModificationMixin): args_schema = AppendStringArgs - def execute( + async def execute( self, args: AppendStringArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/claude_text_editor.py b/moatless/actions/claude_text_editor.py index 6e0fb532..bfe3f4d1 100644 --- a/moatless/actions/claude_text_editor.py +++ b/moatless/actions/claude_text_editor.py @@ -153,7 +153,7 @@ class ClaudeEditTool(Action, CodeModificationMixin): _create_file: CreateFile = PrivateAttr() _repository: Repository | None = PrivateAttr(None) - def execute( + async def execute( self, args: EditActionArguments, file_context: FileContext | None = None, diff --git a/moatless/actions/create_file.py b/moatless/actions/create_file.py index b7ce82f1..4a03ce23 100644 --- a/moatless/actions/create_file.py +++ b/moatless/actions/create_file.py @@ -51,7 +51,7 @@ class CreateFile(Action, CodeActionValueMixin, CodeModificationMixin): args_schema = CreateFileArgs - def execute( + async def execute( self, args: CreateFileArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/finish.py b/moatless/actions/finish.py index 93af45f6..27a7005d 100644 --- a/moatless/actions/finish.py +++ b/moatless/actions/finish.py @@ -46,7 +46,7 @@ class Finish(Action): description="Whether to enforce that the file context has a test patch", ) - def execute( + async def execute( self, args: FinishArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/identify_mixin.py b/moatless/actions/identify_mixin.py index a2381360..beaf0c3f 100644 --- a/moatless/actions/identify_mixin.py +++ b/moatless/actions/identify_mixin.py @@ -73,7 +73,7 @@ class IdentifyMixin(CompletionModelMixin): description="The maximum number of tokens allowed in the identify prompt.", ) - def _identify_code(self, args, view_context: FileContext, max_tokens: int) -> Tuple[FileContext, Completion]: + async def _identify_code(self, args, view_context: FileContext, max_tokens: int) -> Tuple[FileContext, Completion]: """Identify relevant code sections in a large context. Args: @@ -106,7 +106,7 @@ def _identify_code(self, args, view_context: FileContext, max_tokens: int) -> Tu MAX_RETRIES = 3 for retry_attempt in range(MAX_RETRIES): - completion_response = self._completion_model.create_completion(messages=messages) + completion_response = await self._completion_model.create_completion(messages=messages) logger.info( f"Identifying relevant code sections. Attempt {retry_attempt + 1} of {MAX_RETRIES}.{len(completion_response.structured_outputs)} identify requests." ) diff --git a/moatless/actions/insert_line.py b/moatless/actions/insert_line.py index d9813ac8..be5f9b30 100644 --- a/moatless/actions/insert_line.py +++ b/moatless/actions/insert_line.py @@ -55,7 +55,7 @@ class InsertLine(Action, CodeActionValueMixin, CodeModificationMixin): args_schema = InsertLinesArgs - def execute( + async def execute( self, args: InsertLinesArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/list_files.py b/moatless/actions/list_files.py index 27f29acf..8814122f 100644 --- a/moatless/actions/list_files.py +++ b/moatless/actions/list_files.py @@ -33,7 +33,7 @@ def short_summary(self) -> str: class ListFiles(Action): args_schema = ListFilesArgs - def execute( + async def execute( self, args: ListFilesArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/reject.py b/moatless/actions/reject.py index e4379856..e5e4366f 100644 --- a/moatless/actions/reject.py +++ b/moatless/actions/reject.py @@ -25,5 +25,5 @@ def equals(self, other: "ActionArguments") -> bool: class Reject(Action): args_schema: ClassVar[Type[ActionArguments]] = RejectArgs - def execute(self, args: RejectArgs, file_context: FileContext | None = None): + async def execute(self, args: RejectArgs, file_context: FileContext | None = None): return Observation(message=args.rejection_reason, terminal=True) diff --git a/moatless/actions/respond.py b/moatless/actions/respond.py index 0ad883ef..c3684dcd 100644 --- a/moatless/actions/respond.py +++ b/moatless/actions/respond.py @@ -27,7 +27,7 @@ class MessageAction(Action): args_schema = MessageArgs - def execute( + async def execute( self, args: MessageArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/run_tests.py b/moatless/actions/run_tests.py index a4e65d13..8593c600 100644 --- a/moatless/actions/run_tests.py +++ b/moatless/actions/run_tests.py @@ -47,7 +47,7 @@ class RunTests(Action): _repository: Repository = PrivateAttr() _runtime: RuntimeEnvironment = PrivateAttr() - def execute( + async def execute( self, args: RunTestsArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/search_base.py b/moatless/actions/search_base.py index 2998754f..91400852 100644 --- a/moatless/actions/search_base.py +++ b/moatless/actions/search_base.py @@ -106,7 +106,7 @@ def _initialize_completion_model(self): if self._completion_model: self._completion_model.initialize(Identify, IDENTIFY_SYSTEM_PROMPT) - def execute(self, args: SearchBaseArgs, file_context: FileContext | None = None) -> Observation: + async def execute(self, args: SearchBaseArgs, file_context: FileContext | None = None) -> Observation: if file_context is None: raise ValueError("File context must be provided to execute the search action.") @@ -141,7 +141,7 @@ def execute(self, args: SearchBaseArgs, file_context: FileContext | None = None) logger.info( f"{self.name}: Search too large. {properties['search_tokens']} tokens and {search_result_context.span_count()} hits, will ask for clarification." ) - view_context, completion = self._identify_code(args, search_result_context) + view_context, completion = await self._identify_code(args, search_result_context) else: view_context = search_result_context @@ -239,7 +239,7 @@ def _search(self, args: SearchBaseArgs) -> SearchCodeResponse: def _search_for_alternative_suggestion(self, args: SearchBaseArgs) -> SearchCodeResponse: return SearchCodeResponse() - def _identify_code( + async def _identify_code( self, args: SearchBaseArgs, search_result_ctx: FileContext ) -> Tuple[IdentifiedSpans, Completion]: search_result_str = search_result_ctx.create_prompt( @@ -263,7 +263,7 @@ def _identify_code( MAX_RETRIES = 3 for retry_attempt in range(MAX_RETRIES): - completion_response = self.completion_model.create_completion(messages=messages) + completion_response = await self.completion_model.create_completion(messages=messages) logger.info( f"Identifying relevant code sections. Attempt {retry_attempt + 1} of {MAX_RETRIES}.{len(completion_response.structured_outputs)} identify requests." ) diff --git a/moatless/actions/string_replace.py b/moatless/actions/string_replace.py index a5cca39b..f3d0b8da 100644 --- a/moatless/actions/string_replace.py +++ b/moatless/actions/string_replace.py @@ -132,7 +132,7 @@ class StringReplace(Action, CodeActionValueMixin, CodeModificationMixin): description="When True, automatically corrects indentation if all lines have the same indentation difference", ) - def execute(self, args: StringReplaceArgs, file_context: FileContext | None = None) -> Observation: + async def execute(self, args: StringReplaceArgs, file_context: FileContext | None = None) -> Observation: path_str = self.normalize_path(args.path) path, error = self.validate_file_access(path_str, file_context) if error: diff --git a/moatless/actions/verified_finish.py b/moatless/actions/verified_finish.py index 82189d1e..776155de 100644 --- a/moatless/actions/verified_finish.py +++ b/moatless/actions/verified_finish.py @@ -43,7 +43,7 @@ def equals(self, other: "ActionArguments") -> bool: class VerifiedFinish(Action): args_schema: ClassVar[Type[ActionArguments]] = VerifiedFinishArgs - def execute( + async def execute( self, args: VerifiedFinishArgs, file_context: FileContext | None = None, diff --git a/moatless/actions/view_code.py b/moatless/actions/view_code.py index 4ea6e4ac..015ce774 100644 --- a/moatless/actions/view_code.py +++ b/moatless/actions/view_code.py @@ -88,7 +88,7 @@ class ViewCode(Action, IdentifyMixin): description="The maximum number of tokens in the requested code.", ) - def execute( + async def execute( self, args: ViewCodeArgs, file_context: FileContext | None = None, @@ -179,7 +179,7 @@ def execute( view_file.set_patch(file.patch) if view_context.context_size() > self.max_tokens: - view_context, completion = self._identify_code(args, view_context, self.max_tokens) + view_context, completion = await self._identify_code(args, view_context, self.max_tokens) new_span_ids = file_context.add_file_context(view_context) properties["files"][file_path] = { diff --git a/moatless/actions/view_diff.py b/moatless/actions/view_diff.py index d4d8367a..580876a6 100644 --- a/moatless/actions/view_diff.py +++ b/moatless/actions/view_diff.py @@ -30,7 +30,7 @@ class ViewDiff(Action): args_schema = ViewDiffArgs - def execute(self, args: ViewDiffArgs, file_context: FileContext | None = None) -> Observation: + async def execute(self, args: ViewDiffArgs, file_context: FileContext | None = None) -> Observation: diff = file_context.generate_git_patch() if not diff: diff --git a/moatless/agent/agent.py b/moatless/agent/agent.py index 46b8a622..73443a64 100644 --- a/moatless/agent/agent.py +++ b/moatless/agent/agent.py @@ -134,19 +134,16 @@ def set_event_handler(self, handler: Callable): """Set the event handler for agent events""" self._event_handler = handler - def _emit_event(self, event: AgentEvent): + async def _emit_event(self, event: AgentEvent): """Emit a pure agent event""" if self._event_handler: - self._event_handler(event) + await self._event_handler(event) - def run(self, node: Node): + async def run(self, node: Node): """Run the agent on a node to generate and execute an action.""" if not self._completion_model: raise RuntimeError("Completion model not set") - # Emit agent started event - self._emit_event(AgentStarted(agent_id=self.agent_id, node_id=node.node_id)) - if node.action: logger.info(f"Node{node.node_id}: Resetting node") node.reset() @@ -158,7 +155,7 @@ def run(self, node: Node): messages = self._message_generator.generate_messages(node) logger.info(f"Node{node.node_id}: Build action with {len(messages)} messages") - completion_response = self._completion_model.create_completion( + completion_response = await self._completion_model.create_completion( messages=messages, ) node.completions["build_action"] = completion_response.completion @@ -171,12 +168,11 @@ def run(self, node: Node): node.action_steps = [ActionStep(action=action) for action in completion_response.structured_outputs] # Emit action created events for step in node.action_steps: - self._emit_event( + await self._emit_event( AgentActionCreated( agent_id=self.agent_id, node_id=node.node_id, action_name=step.action.name, - action_params=step.action.model_dump(exclude={"name"}), ) ) @@ -218,9 +214,9 @@ def run(self, node: Node): action_names = [action_step.action.name for action_step in node.action_steps] logger.info(f"Node{node.node_id}: Execute actions: {action_names}") for action_step in node.action_steps: - self._execute(node, action_step) + await self._execute(node, action_step) - def _execute(self, node: Node, action_step: ActionStep): + async def _execute(self, node: Node, action_step: ActionStep): action = self.action_map.get(type(action_step.action)) if not action: logger.error( @@ -232,15 +228,14 @@ def _execute(self, node: Node, action_step: ActionStep): ) try: - action_step.observation = action.execute(action_step.action, file_context=node.file_context) + action_step.observation = await action.execute(action_step.action, file_context=node.file_context) # Emit action executed event - self._emit_event( + await self._emit_event( AgentActionExecuted( agent_id=self.agent_id, node_id=node.node_id, action_name=action_step.action.name, - observation=action_step.observation.message if action_step.observation else None, ) ) @@ -253,8 +248,10 @@ def _execute(self, node: Node, action_step: ActionStep): logger.info( f"Executed action: {action_step.action.name}. " - f"Terminal: {action_step.observation.terminal if node.observation else False}. " - f"Output: {action_step.observation.message if node.observation else None}" + f"Terminal: {action_step.observation.terminal if node.observation else False}. ") + + logger.debug( + f"Observation: {action_step.observation.message if node.observation else None}" ) except CompletionRejectError as e: diff --git a/moatless/agent/events.py b/moatless/agent/events.py index 581a85fe..2338e408 100644 --- a/moatless/agent/events.py +++ b/moatless/agent/events.py @@ -7,9 +7,11 @@ class AgentEvent(BaseEvent): """Base class for pure agent events""" - + agent_id: str node_id: int + action_name: Optional[str] = None + action_params: Optional[Dict] = None class AgentStarted(AgentEvent): @@ -22,13 +24,9 @@ class AgentActionCreated(AgentEvent): """Emitted when an agent creates an action""" event_type: str = "agent_action_created" - action_name: str - action_params: Dict class AgentActionExecuted(AgentEvent): """Emitted when an agent executes an action""" event_type: str = "agent_action_executed" - action_name: str - observation: Optional[str] diff --git a/moatless/agentic_system.py b/moatless/agentic_system.py index e32ee4b5..7d2b300a 100644 --- a/moatless/agentic_system.py +++ b/moatless/agentic_system.py @@ -151,12 +151,12 @@ def create( **kwargs, ) - def run(self) -> Node: + async def run(self) -> Node: """Run the system with optional root node.""" try: self._initialize_run_state() self.agent.set_event_handler(self._handle_agent_event) - result = self._run() + result = await self._run() # Complete attempt successfully self._status.complete_current_attempt("completed") @@ -180,18 +180,18 @@ def run(self) -> Node: self.agent.remove_event_handler() @abstractmethod - def _run(self) -> Node: + async def _run(self) -> Node: raise NotImplementedError("Subclass must implement _run method") - def emit_event(self, event: BaseEvent): + async def emit_event(self, event: BaseEvent): """Emit an event.""" logger.info(f"Emit event {event.event_type}") self._save_event(self.run_id, event) - event_bus.publish(self.run_id, event) + await event_bus.publish(self.run_id, event) - def _handle_agent_event(self, event: BaseEvent): + async def _handle_agent_event(self, event: BaseEvent): """Handle agent events and propagate them to system event handlers""" - self.emit_event(event) + await self.emit_event(event) def _initialize_run_state(self): """Initialize or restore system run state and logging""" @@ -272,17 +272,10 @@ def _save_event(self, run_id: str, event: BaseEvent): event_data = None try: - current_attempt = self._status.get_current_attempt() - attempt_id = current_attempt.attempt_id if current_attempt else None - event_data = { 'timestamp': datetime.now(timezone.utc).isoformat(), 'run_id': run_id, - 'event_type': event.event_type, - 'node_id': event.model_dump().get('node_id'), - 'data': event.model_dump(exclude={'event_type', 'node_id'}), - 'restart_count': self._status.restart_count, - 'attempt_id': attempt_id + **event.model_dump(), } self._events_file.write(json.dumps(event_data) + '\n') self._events_file.flush() diff --git a/moatless/api/api.py b/moatless/api/api.py index 647ffed1..af9a2781 100644 --- a/moatless/api/api.py +++ b/moatless/api/api.py @@ -39,31 +39,42 @@ def __init__(self): self.active_connections: Set[WebSocket] = set() async def connect(self, websocket: WebSocket): - logger.info("Connecting to WebSocket") - await websocket.accept() - self.active_connections.add(websocket) - logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}") + try: + logger.info("Accepting WebSocket connection") + await websocket.accept() + self.active_connections.add(websocket) + logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}") + except Exception as e: + logger.error(f"Failed to accept WebSocket connection: {e}") + raise def disconnect(self, websocket: WebSocket): - self.active_connections.discard(websocket) - logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}") + try: + self.active_connections.discard(websocket) + logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}") + except Exception as e: + logger.error(f"Error during WebSocket disconnect: {e}") async def broadcast_message(self, message: dict): """Broadcast message to all connected clients""" + if not self.active_connections: + return + logger.info(f"Broadcasting message to {len(self.active_connections)} clients") + + connections = self.active_connections.copy() disconnected = set() - for connection in self.active_connections: + + for connection in connections: try: await connection.send_text(json.dumps(message)) - except WebSocketDisconnect: + except Exception as e: + logger.error(f"Failed to send message to client: {e}") disconnected.add(connection) - # Clean up disconnected clients for connection in disconnected: self.disconnect(connection) - -# Create a global connection manager instance manager = ConnectionManager() @@ -71,65 +82,48 @@ async def handle_system_event(run_id: str, event: dict): """Handle system events and broadcast them via WebSocket""" message = { 'run_id': run_id, - 'type': event.get('event_type'), - **event + **event.model_dump(exclude_none=True) } await manager.broadcast_message(message) def create_api(workspace: Workspace | None = None) -> FastAPI: """Create and initialize the API with an optional workspace""" - # Load environment variables + load_dotenv() - # Create main FastAPI application api = FastAPI(title="Moatless API") - - # Update CORS middleware with WebSocket origins - origins = [ - "http://localhost:5173", - "http://127.0.0.1:5173", - "http://[::1]:5173", - "http://localhost:4173", - "http://127.0.0.1:4173", - "http://[::1]:4173", - # Add WebSocket origins - "ws://localhost:5173", - "ws://127.0.0.1:5173", - "ws://[::1]:5173", - # Add development API server origins - "ws://localhost:8000", - "ws://127.0.0.1:8000", - "http://localhost:8000", - "http://127.0.0.1:8000" - ] - api.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allow all origins in development + allow_origins=["*"], allow_credentials=True, - allow_methods=["*"], # Allow all methods + allow_methods=["*"], allow_headers=["*"], max_age=3600, ) - # Create API router with /api prefix router = FastAPI(title="Moatless API") @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): try: await manager.connect(websocket) - # Keep connection alive with ping/pong + while True: try: - data = await websocket.receive_text() - # Handle incoming messages if needed + # Wait for messages but don't do anything with them + # This keeps the connection alive + await websocket.receive_text() except WebSocketDisconnect: break + except Exception as e: + logger.error(f"Error in WebSocket connection: {e}") + break + except Exception as e: + logger.error(f"Failed to establish WebSocket connection: {e}") + raise finally: manager.disconnect(websocket) - logger.info("WebSocket connection closed") if workspace is not None: @@ -226,6 +220,7 @@ async def serve_spa(request: Request, full_path: str): except ImportError: logger.info("API extras not installed, UI will not be served") + logger.info("Subscribing to system events") # Subscribe to system events event_bus.subscribe(handle_system_event) diff --git a/moatless/api/runs/api.py b/moatless/api/runs/api.py index 7245526b..8f3c80f8 100644 --- a/moatless/api/runs/api.py +++ b/moatless/api/runs/api.py @@ -3,11 +3,13 @@ import os import json import logging +import time from fastapi import APIRouter, HTTPException +from moatless.agentic_system import SystemStatus from moatless.api.trajectory.schema import TrajectoryDTO from moatless.runner import agentic_runner from moatless.api.trajectory.trajectory_utils import convert_nodes, create_trajectory_dto, load_trajectory_from_file -from .schema import RunResponseDTO, RunStatusDTO, RunEventDTO +from .schema import RunResponseDTO, RunEventDTO from pathlib import Path from datetime import datetime @@ -43,7 +45,7 @@ def load_run_events(run_dir: Path) -> list[RunEventDTO]: return events -def load_run_status(run_dir: Path) -> RunStatusDTO: +def load_run_status(run_dir: Path) -> SystemStatus: """Load status from status.json file.""" status_path = run_dir / 'status.json' @@ -55,26 +57,24 @@ def load_run_status(run_dir: Path) -> RunStatusDTO: for dt_field in ['started_at', 'finished_at', 'last_restart']: if status_data.get(dt_field): status_data[dt_field] = datetime.fromisoformat(status_data[dt_field]) - return RunStatusDTO(**status_data) + return SystemStatus(**status_data) except Exception as e: logger.error(f"Error reading status file: {e}") # Return default status if file doesn't exist or has errors - return RunStatusDTO( + return SystemStatus( status="unknown", started_at=datetime.utcnow() ) @router.get("/{run_id}", response_model=RunResponseDTO) async def get_run(run_id: str): - logger.info(f"Getting run {run_id}") """Get the status, trajectory data, and events for a specific run.""" try: - run_dir = get_run_dir(run_id) # First try to get active run from runner - system = agentic_runner.get_run(run_id) - + system = await agentic_runner.get_run(run_id) + run_dir = get_run_dir(run_id) if system: # Active run found - get status and trajectory from system status = "running" @@ -92,8 +92,6 @@ async def get_run(run_id: str): system_status = system.get_status() else: - logger.info(f"Run {run_id} not found in runner, trying to load from file") - # Try to load completed run from trajectory file trajectory_path = get_trajectory_path(run_id) try: trajectory = load_trajectory_from_file(trajectory_path) @@ -101,7 +99,7 @@ async def get_run(run_id: str): except FileNotFoundError: raise HTTPException(status_code=404, detail="Run not found") - # Load status from file + system_status = load_run_status(run_dir) if system_status.status == "running": system_status.status = "stopped" diff --git a/moatless/api/runs/schema.py b/moatless/api/runs/schema.py index 2c880a78..de128eff 100644 --- a/moatless/api/runs/schema.py +++ b/moatless/api/runs/schema.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from typing import Optional, List, Dict, Any from datetime import datetime +from moatless.agentic_system import SystemStatus from moatless.api.trajectory.schema import TrajectoryDTO class RunEventDTO(BaseModel): @@ -8,24 +9,12 @@ class RunEventDTO(BaseModel): timestamp: int # Changed from datetime to int (milliseconds) event_type: str node_id: Optional[int] = None - data: Dict[str, Any] - attempt_id: Optional[int] = None - restart_count: Optional[int] = None - -class RunStatusDTO(BaseModel): - """Status of a run.""" - status: str = "running" - error: Optional[str] = None - started_at: datetime - finished_at: Optional[datetime] = None - restart_count: int = 0 - last_restart: Optional[datetime] = None - metadata: Dict[str, Any] = {} - current_attempt: Optional[int] = None + agent_id: Optional[str] = None + action_name: Optional[str] = None class RunResponseDTO(BaseModel): """Response containing run status, trajectory data, and events.""" status: str - system_status: RunStatusDTO + system_status: SystemStatus trajectory: Optional[TrajectoryDTO] = None events: List[RunEventDTO] = [] \ No newline at end of file diff --git a/moatless/api/swebench/api.py b/moatless/api/swebench/api.py index 14872ed8..bda458f1 100644 --- a/moatless/api/swebench/api.py +++ b/moatless/api/swebench/api.py @@ -70,22 +70,17 @@ async def validate_instance(request: SWEBenchValidationRequestDTO): # Initialize the validator validator = CodeFlowValidation() - # Start the validation in background - async def run_validation(): - try: - validator.start_code_loop( - run_id=run_id, - agent_id=request.agent_id, - model_id=request.model_id, - instance_id=request.instance_id, - max_iterations=request.max_iterations - ) - except Exception as e: - logger.exception(f"Validation failed: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - # Start validation in background task - asyncio.create_task(run_validation()) + try: + await validator.start_code_loop( + run_id=run_id, + agent_id=request.agent_id, + model_id=request.model_id, + instance_id=request.instance_id, + max_iterations=request.max_iterations + ) + except Exception as e: + logger.exception(f"Validation failed: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) return SWEBenchValidationResponseDTO( run_id=run_id diff --git a/moatless/cli.py b/moatless/cli.py index f0426c36..9d94578b 100644 --- a/moatless/cli.py +++ b/moatless/cli.py @@ -99,9 +99,8 @@ def __init__(self): ) self.events_file = None event_bus.subscribe(self.handle_event_sync) - # event_bus.subscribe(self.handle_event) # Keep async handler - def handle_event_sync(self, run_id: str, event: BaseEvent): + async def handle_event_sync(self, run_id: str, event: BaseEvent): """Synchronous event handler for immediate console output""" # Get event data event_data = event.model_dump() @@ -143,8 +142,6 @@ def handle_event_sync(self, run_id: str, event: BaseEvent): if event_data: self.console.print(event_data, style="bright_black") - async def handle_event(self, run_id: str, event: BaseEvent): - """Async event handler for file writing and state updates""" if run_id not in self.active_runs: self.active_runs[run_id] = { 'status': 'initializing', @@ -204,7 +201,7 @@ async def run_validation(self, try: # Start the code loop with our run_id - self.validator.start_code_loop( + await self.validator.start_code_loop( run_id=run_id, agent_id=agent_id, model_id=model_id, diff --git a/moatless/completion/anthropic.py b/moatless/completion/anthropic.py deleted file mode 100644 index bda25822..00000000 --- a/moatless/completion/anthropic.py +++ /dev/null @@ -1,231 +0,0 @@ -import json -import logging -from typing import Optional, Union, List - -import anthropic -import tenacity -from anthropic import Anthropic, AnthropicBedrock, NOT_GIVEN -from anthropic.types import ToolUseBlock, TextBlock -from anthropic.types.beta import ( - BetaToolUseBlock, - BetaTextBlock, -) -from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt -from pydantic import Field, ValidationError - -from moatless.completion import CompletionModel -from moatless.completion.completion import LLMResponseFormat, CompletionResponse -from moatless.completion.model import Completion, StructuredOutput, Usage -from moatless.exceptions import CompletionRejectError, CompletionRuntimeError - -logger = logging.getLogger(__name__) - - -class AnthtropicCompletionModel(CompletionModel): - response_format: Optional[LLMResponseFormat] = Field( - LLMResponseFormat.TOOLS, description="The response format expected from the LLM" - ) - - @property - def supports_anthropic_computer_use(self): - return "claude-3-5-sonnet-20241022" in self.model - - def create_completion( - self, - messages: List[dict], - system_prompt: str, - response_model: List[type[StructuredOutput]] | type[StructuredOutput], - ) -> CompletionResponse: - # Convert Message objects to dictionaries if needed - messages = [msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in messages] - - total_usage = Usage() - retry_count = 0 - - tools = [] - tool_choice = {"type": "any"} - - actions = [] - if not response_model: - tools = NOT_GIVEN - tool_choice = NOT_GIVEN - else: - if isinstance(response_model, list): - actions = response_model - elif response_model: - actions = [response_model] - - for action in actions: - if hasattr(action, "name") and action.name == "str_replace_editor": - tools.append({"name": "str_replace_editor", "type": "text_editor_20241022"}) - else: - schema = action.anthropic_schema() - - # Remove scratch pad field, use regular text block for thoughts - if "thoughts" in schema["input_schema"]["properties"]: - del schema["input_schema"]["properties"]["thoughts"] - - tools.append(schema) - - system_message = {"text": system_prompt, "type": "text"} - - anthropic_messages = anthropic_messages_pt( - model=self.model, - messages=messages, - llm_provider="anthropic", - ) - if "anthropic" in self.model: - anthropic_client = AnthropicBedrock() - betas = ["computer-use-2024-10-22"] # , "prompt-caching-2024-07-31"] - extra_headers = {} # "X-Amzn-Bedrock-explicitPromptCaching": "enabled"} - else: - anthropic_client = Anthropic() - extra_headers = {} - betas = ["computer-use-2024-10-22", "prompt-caching-2024-07-31"] - _inject_prompt_caching(anthropic_messages) - system_message["cache_control"] = {"type": "ephemeral"} - - retries = tenacity.Retrying( - retry=tenacity.retry_if_not_exception_type(anthropic.BadRequestError), - stop=tenacity.stop_after_attempt(3), - ) - - def _do_completion(): - nonlocal retry_count, total_usage - - completion_response = None - try: - completion_response = anthropic_client.beta.messages.create( - model=self.model, - max_tokens=self.max_tokens, - temperature=self.temperature, - system=[system_message], - tools=tools, - messages=anthropic_messages, - betas=betas, - extra_headers=extra_headers, - ) - - total_usage += Usage.from_completion_response(completion_response, self.model) - - def get_response_format(name: str): - if len(actions) == 1: - return actions[0] - else: - for check_action in actions: - if check_action.name == block.name: - return check_action - return None - - text = None - structured_outputs = [] - for block in completion_response.content: - if isinstance(block, ToolUseBlock) or isinstance(block, BetaToolUseBlock): - action = None - - tool_call_id = block.id - - if len(actions) == 1: - action = actions[0] - else: - for check_action in actions: - if check_action.name == block.name: - action = check_action - break - - if not action: - raise ValueError(f"Unknown action {block.name}") - - action_args = action.model_validate(block.input) - structured_outputs.append(action_args) - - elif isinstance(block, TextBlock) or isinstance(block, BetaTextBlock): - text = block.text - - else: - logger.warning(f"Unexpected block {block}]") - - completion = Completion.from_llm_completion( - input_messages=messages, - completion_response=completion_response, - model=self.model, - usage=total_usage, - retries=retry_count, - ) - - # Log summary of the response - action_names = [output.__class__.__name__ for output in structured_outputs] - has_text = bool(text and text.strip()) - if action_names: - logger.info(f"Completion response summary - Actions: {action_names}, Has text: {has_text}") - else: - logger.info(f"Completion response summary - Text only: {text[:200]}...") - - return CompletionResponse( - structured_outputs=structured_outputs, - text_response=text, - completion=completion, - ) - - except ValidationError as e: - logger.warning( - f"Validation failed with error {e}. Response: {json.dumps(completion_response.model_dump() if completion_response else None, indent=2)}" - ) - messages.append( - { - "role": "assistant", - "content": [block.model_dump() for block in completion_response.content], - } - ) - messages.append( - { - "role": "user", - "content": [ - { - "tool_use_id": tool_call_id, - "content": f"\nThe response was invalid. Fix the errors: {e}\n", - "type": "tool_result", - } - ], - } - ) - retry_count += 1 - raise CompletionRejectError( - message=str(e), - last_completion=completion_response, - messages=messages, - ) from e - except Exception as e: - raise CompletionRuntimeError( - f"Failed to get completion response: {e}", - messages=messages, - last_completion=completion_response, - ) - - try: - return retries(_do_completion) - except tenacity.RetryError as e: - raise e.reraise() - - -def _inject_prompt_caching( - messages: list[Union["AnthropicMessagesUserMessageParam", "AnthopicMessagesAssistantMessageParam"]], -): - from anthropic.types.beta import BetaCacheControlEphemeralParam - - """ - Set cache breakpoints for the 3 most recent turns - one cache breakpoint is left for tools/system prompt, to be shared across sessions - """ - - breakpoints_remaining = 3 - for message in reversed(messages): - # message["role"] == "user" and - if isinstance(content := message["content"], list): - if breakpoints_remaining: - breakpoints_remaining -= 1 - content[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"}) - else: - content[-1].pop("cache_control", None) - # we'll only every have one extra turn per loop - break diff --git a/moatless/completion/base.py b/moatless/completion/base.py index 9b7beff3..d93e8f1c 100644 --- a/moatless/completion/base.py +++ b/moatless/completion/base.py @@ -164,7 +164,7 @@ def initialize( def initialized(self) -> bool: return self._initialized - def create_completion( + async def create_completion( self, messages: List[dict], ) -> CompletionResponse: @@ -174,7 +174,8 @@ def create_completion( ) prepared_messages = self._prepare_messages(messages, self._system_prompt) - return self._create_completion_with_retries(messages=prepared_messages) + return await self._create_completion_with_retries(messages=prepared_messages) + def _prepare_system_prompt( self, @@ -221,7 +222,7 @@ def _get_completion_params(self, schema: type[ResponseSchema]) -> dict[str, Unio """ return {} - def _create_completion_with_retries( + async def _create_completion_with_retries( self, messages: List[dict], ) -> CompletionResponse: @@ -239,12 +240,12 @@ def _create_completion_with_retries( f"Retrying litellm completion after error: {retry_state.outcome.exception()}" ), ) - def _do_completion_with_validation(): + async def _do_completion_with_validation(): nonlocal retry_count, accumulated_usage, completion_response retry_count += 1 # Execute completion and get raw response - completion_response = self._execute_completion(messages) + completion_response = await self._execute_completion(messages) # Track usage from this attempt regardless of validation outcome usage = Usage.from_completion_response(completion_response, self.model) @@ -295,7 +296,7 @@ def _do_completion_with_validation(): ) try: - return _do_completion_with_validation() + return await _do_completion_with_validation() except CompletionRetryError as e: logger.warning( f"Completion failed after {retry_count} retries. Exception: {e}. Completion response: {completion_response}" @@ -317,7 +318,7 @@ def _do_completion_with_validation(): accumulated_usage=accumulated_usage, ) from e - def _execute_completion( + async def _execute_completion( self, messages: List[Dict[str, str]], ): @@ -354,7 +355,7 @@ def _execute_completion( f"Rate limited by provider, retrying in {retry_state.next_action.sleep} seconds" ), ) - def _do_completion_with_rate_limit_retry(): + async def _do_completion_with_rate_limit_retry(): try: if "claude-3-5" in self.model: self._inject_prompt_caching(messages) @@ -364,7 +365,7 @@ def _do_completion_with_rate_limit_retry(): if self.model_api_key: params["api_key"] = self.model_api_key - return litellm.completion( + return await litellm.acompletion( model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, @@ -399,7 +400,7 @@ def _do_completion_with_rate_limit_retry(): raise CompletionRuntimeError(message=str(e), messages=messages) from e - return _do_completion_with_rate_limit_retry() + return await _do_completion_with_rate_limit_retry() def _get_schema_names(self): return [schema.__name__ for schema in self._response_schema] if self._response_schema else ["None"] diff --git a/moatless/events.py b/moatless/events.py index 7fefed04..303f9793 100644 --- a/moatless/events.py +++ b/moatless/events.py @@ -71,32 +71,19 @@ def get_instance(cls) -> "EventBus": return cls._instance def subscribe(self, callback: Callable): + logger.info(f"Subscribing to event: {callback.__name__}") self._subscribers.append(callback) + logger.info(f"Subscribed to {len(self._subscribers)} events") - def publish(self, run_id: str, event: BaseEvent): + async def publish(self, run_id: str, event: BaseEvent): """Publish event, handling both sync and async subscribers""" - # Handle sync subscribers immediately - logger.info(f"Publishing event: {event.event_type}") - - for callback in self._subscribers: - if not asyncio.iscoroutinefunction(callback): - try: - callback(run_id, event) - except Exception as e: - logger.error(f"Error in event subscriber: {e}") - - try: - loop = asyncio.get_running_loop() - loop.create_task(self.publish_async(run_id, event)) - except RuntimeError: - # No event loop running - skip async subscribers - pass - - async def publish_async(self, run_id: str, event: BaseEvent): - """Asynchronous publish for async subscribers""" - for callback in self._subscribers: - if asyncio.iscoroutinefunction(callback): - await callback(run_id, event) + logger.info(f"Publishing event: {event.event_type} to {len(self._subscribers)} subscribers") + await asyncio.gather(*[self._run_async_callback(callback, run_id, event) for callback in self._subscribers]) + + async def _run_async_callback(self, callback: Callable, run_id: str, event: BaseEvent): + """Helper method to run a single async callback""" + logger.info(f"Running async callback: {callback.__name__}") + await callback(run_id, event) event_bus = EventBus.get_instance() \ No newline at end of file diff --git a/moatless/loop.py b/moatless/loop.py index a4f2d83a..3bad2868 100644 --- a/moatless/loop.py +++ b/moatless/loop.py @@ -24,7 +24,7 @@ class AgenticLoop(AgenticSystem): model_config = ConfigDict(arbitrary_types_allowed=True) - def _run(self) -> Node: + async def _run(self) -> Node: """Run the agentic loop until completion or max iterations.""" current_node = self.root.get_all_nodes()[-1] @@ -49,7 +49,7 @@ def _run(self) -> Node: try: current_node = self._create_next_node(current_node) - self.agent.run(current_node) + await self.agent.run(current_node) self.maybe_persist() self.log(logger.info, generate_ascii_tree(self.root, current_node)) except RejectError as e: diff --git a/moatless/runner.py b/moatless/runner.py index 44fa1da6..4259600f 100644 --- a/moatless/runner.py +++ b/moatless/runner.py @@ -1,5 +1,6 @@ import asyncio import logging +import time from typing import Callable, Dict, List, Tuple from moatless.agentic_system import AgenticSystem @@ -20,11 +21,11 @@ def get_instance(cls) -> "AgenticRunner": cls._instance = cls() return cls._instance - def start(self, agentic_system: AgenticSystem) -> str: + async def start(self, agentic_system: AgenticSystem) -> str: run_id = agentic_system.run_id async def trajectory_wrapper() -> dict: - result = await asyncio.to_thread(agentic_system.run) + result = await agentic_system.run() # Clean up when done. self.active_runs.pop(run_id, None) return result @@ -34,23 +35,12 @@ async def trajectory_wrapper() -> dict: self.active_runs[run_id] = (agentic_system, task_obj) return run_id - def get_run(self, run_id: str) -> AgenticSystem | None: + async def get_run(self, run_id: str) -> AgenticSystem | None: + start_time = time.time() entry = self.active_runs.get(run_id) if entry is None: return None + logger.info(f"Run {run_id} took {time.time() - start_time} seconds to get run from runner") return entry[0] - def get_status(self, run_id: str) -> dict: - entry = self.active_runs.get(run_id) - if entry is None: - return {"error": "Run not found or finished."} - agentic_system, task_obj = entry - status = agentic_system.get_status() - if task_obj.done(): - status["result"] = task_obj.result() - status["status"] = "finished" - else: - status["status"] = "running" - return status - agentic_runner = AgenticRunner.get_instance() \ No newline at end of file diff --git a/moatless/search_tree.py b/moatless/search_tree.py index 77333652..19902deb 100644 --- a/moatless/search_tree.py +++ b/moatless/search_tree.py @@ -131,7 +131,7 @@ def model_validate( instance = super().model_validate(obj) return instance - def _run(self) -> Node: + async def _run(self) -> Node: """Run the search tree algorithm with the given node.""" if not self.root: raise ValueError("No node provided to run") diff --git a/moatless/validation/code_flow_validation.py b/moatless/validation/code_flow_validation.py index 830a5cd7..2b518239 100644 --- a/moatless/validation/code_flow_validation.py +++ b/moatless/validation/code_flow_validation.py @@ -49,7 +49,7 @@ def setup_run_directory(self, run_dir: str) -> dict: return dirs - def start_code_loop(self, + async def start_code_loop(self, run_id: str, agent_id: str, model_id: str, @@ -103,5 +103,5 @@ def start_code_loop(self, } ) - agentic_runner.start(loop) + await agentic_runner.start(loop) return run_id