Skip to content

Commit

Permalink
Run ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
aorwall committed Jul 30, 2024
1 parent 6d4bb5c commit 0d93162
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
- name: Install dependencies
run: poetry install
- name: Run tests
run: poetry run pytest
run: poetry run pytest
3 changes: 2 additions & 1 deletion moatless/benchmark/claude_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def evaluate_plan(previous_trajectory_dir: Optional[str] = None):
for instance_id in df.index:
print(df.loc[instance_id, "instance_id"], df.loc[instance_id, "planned"])


if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
Expand All @@ -350,7 +351,7 @@ def evaluate_plan(previous_trajectory_dir: Optional[str] = None):
# evaluate_search_and_identify()
evaluate_search_and_code(
1,
"/home/albert/repos/albert/moatless/evaluations/20240623_moatless_claude-3.5-sonnet/trajs",
"/home/albert/repos/albert/moatless/evaluations/20240623_moatless_claude-3.5-sonnet/trajs",
retry_state="PlanToCode",
)
# evaluate_search_and_code()
Expand Down
19 changes: 15 additions & 4 deletions moatless/benchmark/swebench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,14 @@ def generate_md_report(trajectory: dict, instance: dict):
return markdown


def setup_swebench_repo(instance_data: dict | None = None, instance_id: str = None, repo_base_dir: str | None = None) -> str:
assert instance_data or instance_id, "Either instance_data or instance_id must be provided"
def setup_swebench_repo(
instance_data: dict | None = None,
instance_id: str = None,
repo_base_dir: str | None = None,
) -> str:
assert (
instance_data or instance_id
), "Either instance_data or instance_id must be provided"
if not instance_data:
instance_data = load_instance(instance_id)

Expand All @@ -297,7 +303,7 @@ def create_workspace(
instance: dict | None = None,
instance_id: str | None = None,
repo_base_dir: str | None = None,
index_store_dir: str | None = None
index_store_dir: str | None = None,
):
"""
Create a workspace for the given SWE-bench instance.
Expand All @@ -321,4 +327,9 @@ def create_workspace(
persist_dir = os.path.join(
index_store_dir, get_repo_dir_name(instance["instance_id"])
)
return Workspace.from_dirs(git_repo_url=repo_url, commit=instance["base_commit"], repo_dir=repo_dir, index_dir=persist_dir)
return Workspace.from_dirs(
git_repo_url=repo_url,
commit=instance["base_commit"],
repo_dir=repo_dir,
index_dir=persist_dir,
)
6 changes: 5 additions & 1 deletion moatless/codeblocks/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ def _extract_node_type(self, query: str):
return None

def _build_queries(self, query_file: str):
with resources.files("moatless.codeblocks.parser.queries").joinpath(query_file).open() as file:
with (
resources.files("moatless.codeblocks.parser.queries")
.joinpath(query_file)
.open() as file
):
query_list = file.read().strip().split("\n\n")
parsed_queries = []
for i, query in enumerate(query_list):
Expand Down
2 changes: 1 addition & 1 deletion moatless/codeblocks/parser/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ def post_process(self, codeblock: CodeBlock):
def is_outcommented_code(self, comment):
return comment.startswith("# ...") or any(
keyword in comment.lower() for keyword in commented_out_keywords
)
)
4 changes: 2 additions & 2 deletions moatless/edit/review.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ class ApplyChanges(ActionRequest):


class ReviewCode(AgenticState):

message: Optional[str] = Field(
None,
description="Message to the coder",
Expand Down Expand Up @@ -130,7 +129,8 @@ class ReviewCode(AgenticState):
)

finish_on_no_errors: bool = Field(
False, description="Whether to finish the task if no verification errors are found."
False,
description="Whether to finish the task if no verification errors are found.",
)

_verification_errors: List[VerificationError] = PrivateAttr(default_factory=list)
Expand Down
40 changes: 22 additions & 18 deletions moatless/file_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, **data):

def model_dump(self, **kwargs):
data = super().model_dump(**kwargs, exclude={"file"})
data['file_path'] = self.file.file_path
data["file_path"] = self.file.file_path
return data

@property
Expand Down Expand Up @@ -398,7 +398,6 @@ def expand_small_classes(self, max_tokens: int):


class FileContext(BaseModel):

_repo: FileRepository = PrivateAttr()
_file_context: Dict[str, ContextFile] = PrivateAttr(default_factory=dict)
_max_tokens: int = PrivateAttr(default=4000)
Expand All @@ -408,10 +407,10 @@ class FileContext(BaseModel):
def __init__(self, repo: FileRepository, **data):
super().__init__(**data)
self._repo = repo
if '_file_context' not in self.__dict__:
self.__dict__['_file_context'] = {}
if '_max_tokens' not in self.__dict__:
self.__dict__['_max_tokens'] = data.get('max_tokens', 4000)
if "_file_context" not in self.__dict__:
self.__dict__["_file_context"] = {}
if "_max_tokens" not in self.__dict__:
self.__dict__["_max_tokens"] = data.get("max_tokens", 4000)

@classmethod
def from_dir(cls, repo_dir: str, max_tokens: int = 4000):
Expand All @@ -423,7 +422,7 @@ def from_dir(cls, repo_dir: str, max_tokens: int = 4000):
def from_json(cls, repo_dir: str, json_data: str):
"""
Create a FileContext instance from JSON data.
:param repo_dir: The repository directory path.
:param json_data: A JSON string representing the FileContext data.
:return: A new FileContext instance.
Expand All @@ -434,11 +433,11 @@ def from_json(cls, repo_dir: str, json_data: str):
@classmethod
def from_dict(cls, repo_dir: str, data: Dict):
repo = FileRepository(repo_dir)
instance = cls(max_tokens=data.get('max_tokens', 4000), repo=repo)
for file_data in data.get('files', []):
file_path = file_data['file_path']
show_all_spans = file_data.get('show_all_spans', False)
spans = [ContextSpan(**span) for span in file_data.get('spans', [])]
instance = cls(max_tokens=data.get("max_tokens", 4000), repo=repo)
for file_data in data.get("files", []):
file_path = file_data["file_path"]
show_all_spans = file_data.get("show_all_spans", False)
spans = [ContextSpan(**span) for span in file_data.get("spans", [])]
instance._file_context[file_path] = ContextFile(
file=instance._repo.get_file(file_path),
spans=spans,
Expand All @@ -448,11 +447,14 @@ def from_dict(cls, repo_dir: str, data: Dict):
return instance

def model_dump(self, **kwargs):
if 'exclude_none' not in kwargs:
kwargs['exclude_none'] = True
if "exclude_none" not in kwargs:
kwargs["exclude_none"] = True

files = [file.model_dump(**kwargs) for file in self.__dict__['_file_context'].values()]
return {"max_tokens": self.__dict__['_max_tokens'], "files": files}
files = [
file.model_dump(**kwargs)
for file in self.__dict__["_file_context"].values()
]
return {"max_tokens": self.__dict__["_max_tokens"], "files": files}

def snapshot(self):
dict = self.model_dump()
Expand Down Expand Up @@ -525,7 +527,9 @@ def add_spans_to_context(
else:
logger.warning(f"Could not find file {file_path} in the repository")

def add_span_to_context(self, file_path: str, span_id: str, tokens: Optional[int] = None):
def add_span_to_context(
self, file_path: str, span_id: str, tokens: Optional[int] = None
):
context_file = self.get_context_file(file_path)
if context_file:
context_file.add_span(span_id, tokens)
Expand Down Expand Up @@ -714,7 +718,7 @@ def reset(self):
self._file_context = {}

def strip_line_breaks_only(self, text):
return text.lstrip('\n\r').rstrip('\n\r')
return text.lstrip("\n\r").rstrip("\n\r")

def create_prompt(
self,
Expand Down
5 changes: 4 additions & 1 deletion moatless/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ def run(
raise Exception("Loop is already running.")

self._trajectory = Trajectory(
"AgenticLoop", initial_message=message, persist_path=self._trajectory_path, workspace=self.workspace.dict()
"AgenticLoop",
initial_message=message,
persist_path=self._trajectory_path,
workspace=self.workspace.dict(),
)

self.transition_to(self._transitions.initial_state(**input_data or {}))
Expand Down
10 changes: 3 additions & 7 deletions moatless/repository/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import os
from dataclasses import dataclass
from typing import Optional

from pydantic import BaseModel, ConfigDict

Expand All @@ -30,7 +29,7 @@ class CodeFile(BaseModel):
module: Module | None = None
dirty: bool = False

model_config = ConfigDict(exclude={'module', 'dirty'})
model_config = ConfigDict(exclude={"module", "dirty"})

@classmethod
def from_file(cls, repo_path: str, file_path: str):
Expand Down Expand Up @@ -177,10 +176,7 @@ def __init__(self, repo_path: str):
self._files: dict[str, CodeFile] = {}

def dict(self):
return {
"type": "file",
"path": self._repo_path
}
return {"type": "file", "path": self._repo_path}

def snapshot(self) -> dict:
return {}
Expand Down Expand Up @@ -298,4 +294,4 @@ def do_diff(file_path: str, original_content: str, updated_content: str) -> str
tofile=file_path,
lineterm="\n",
)
)
)
10 changes: 7 additions & 3 deletions moatless/repository/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@

logger = logging.getLogger(__name__)

class GitRepository(FileRepository):

class GitRepository(FileRepository):
def __init__(self, repo_path: str, repo_url: str | None, commit: str | None = None):
super().__init__(repo_path)
self._repo_path = repo_path
self._repo_url = repo_url
self._repo = Repo(path=repo_path)
if not self._repo.heads:
raise Exception("Git repository has no heads, you need to do an initial commit.")
raise Exception(
"Git repository has no heads, you need to do an initial commit."
)

# TODO: Check if current branch is mainline

Expand All @@ -26,7 +28,9 @@ def __init__(self, repo_path: str, repo_url: str | None, commit: str | None = No

@classmethod
def from_repo(cls, repo_url: str, repo_path: str, commit: str | None = None):
logger.info(f"Clone GitRepository from {repo_url} with commit {commit} to {repo_path} ")
logger.info(
f"Clone GitRepository from {repo_url} with commit {commit} to {repo_path} "
)

maybe_clone(repo_url, repo_path)

Expand Down
9 changes: 5 additions & 4 deletions moatless/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class AgenticState(ABC, BaseModel):
default=False,
description="The message history from previous initations should be included in the completion request",
)
model: str | None = Field(default=None, description="The model to use for completion")
model: str | None = Field(
default=None, description="The model to use for completion"
)
temperature: float = Field(0.0, description="The temperature to use for completion")
max_tokens: int = Field(
1000, description="The maximum number of tokens to generate"
Expand Down Expand Up @@ -122,7 +124,7 @@ def stop_words(self) -> list[str] | None:

def model_dump(self, *args, **kwargs):
data = super().model_dump(*args, **kwargs)
data['name'] = self.name
data["name"] = self.name
return data


Expand Down Expand Up @@ -153,5 +155,4 @@ def __init__(self, message: str, **kwargs):

class Pending(NoopState):
def __init__(self, **data):

super().__init__(**data)
super().__init__(**data)
14 changes: 7 additions & 7 deletions moatless/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def name(self):
def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
if self.state:
data['state']['name'] = self.state.name
data['actions'] = [action.model_dump(**kwargs) for action in self.actions]
data["state"]["name"] = self.state.name

data["actions"] = [action.model_dump(**kwargs) for action in self.actions]

return data


Expand All @@ -52,7 +52,7 @@ def __init__(
name: str,
initial_message: str | None = None,
persist_path: str | None = None,
workspace: dict | None = None
workspace: dict | None = None,
):
self._name = name
self._persist_path = persist_path
Expand Down Expand Up @@ -139,7 +139,7 @@ def to_dict(self, **kwargs):
"initial_message": self._initial_message,
"transitions": transition_dicts,
"info": self._info,
"dummy_field": None # Add this line
"dummy_field": None, # Add this line
}

def total_cost(self):
Expand All @@ -163,4 +163,4 @@ def persist(self, file_path: str):
indent=2,
default=to_jsonable_python,
)
)
)
4 changes: 3 additions & 1 deletion moatless/utils/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ def setup_github_repo(repo: str, base_commit: str, base_dir: str = "/tmp/repos")
repo_name = get_repo_dir_name(repo)
repo_url = f"https://github.com/{repo}.git"
path = f"{base_dir}/{repo_name}"
logger.info(f"Clone Github repo {repo_url} to {path} and checkout commit {base_commit}")
logger.info(
f"Clone Github repo {repo_url} to {path} and checkout commit {base_commit}"
)
if not os.path.exists(path):
os.makedirs(path)
logger.info(f"Directory '{path}' was created.")
Expand Down
9 changes: 5 additions & 4 deletions moatless/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class Workspace:

def __init__(
self,
file_repo: FileRepository,
Expand Down Expand Up @@ -50,7 +49,9 @@ def from_dirs(
**kwargs,
):
if git_repo_url:
file_repo = GitRepository.from_repo(repo_url=git_repo_url, repo_path=repo_dir, commit=commit)
file_repo = GitRepository.from_repo(
repo_url=git_repo_url, repo_path=repo_dir, commit=commit
)
elif repo_dir:
file_repo = FileRepository(repo_dir)
else:
Expand Down Expand Up @@ -83,13 +84,13 @@ def from_dirs(
def dict(self):
return {
"repository": self.file_repo.dict(),
"file_context": self.file_context.model_dump()
"file_context": self.file_context.model_dump(),
}

def snapshot(self):
return {
"repository": self.file_repo.snapshot(),
"file_context": self.file_context.snapshot()
"file_context": self.file_context.snapshot(),
}

def create_file_context(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from datetime import datetime

import logging
from pydantic import Field

Expand Down

0 comments on commit 0d93162

Please sign in to comment.