mllm_prompt.py
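"""Drive an LLM-assisted code-change pipeline for a single generation task.

The script builds a prompt from the project files, asks a planning model for a
change plan, has a diff model regenerate the affected files (verbatim rewrites
or unified diffs), snapshots the result as a new btrfs revision, builds it in a
sandbox, and records each attempt as a GenerationOutput row. Optional
second-stage models review the diff and retry until the build succeeds or the
phase budget is exhausted.
"""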
import mllm_client, pathlib, concurrent.futures, os, subprocess, json
from pydantic import BaseModel
from pathlib import Path
import argparse
import yaml, logging, shortuuid, sqlite3, reactivedb
from mllm_types import EnvConfig, GenerationOutput, GenerationEnv
from mllm_client import Message
import mllm_sandbox, time, mllm_state
from mllm_state import safe_child
import tempfile
_pool = None
def thread_pool() -> concurrent.futures.ThreadPoolExecutor:
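    """Return the shared ThreadPoolExecutor, creating it lazily (20 workers)."""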
    global _pool
    if _pool is None:
        _pool = concurrent.futures.ThreadPoolExecutor(20)
    return _pool
def llm_code_block_with_retry(messages: list['Message'], model, prediction=None):
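    """Prefill an assistant turn with an opening code fence and keep requesting
    continuations until the fence is closed, then return the code inside it."""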
    messages = list(dict(m) for m in messages)  # type: ignore
    # Prefill the assistant turn with an opening fence so the model starts inside a code block.
    messages.append({'role': 'assistant', 'content': '```'})
    while True:
        assert messages[-1]['role'] == 'assistant', messages
        response = mllm_client.llm_output(messages, model, prediction=prediction)
        messages[-1]['content'] += ''.join(response)
        # Keep requesting continuations until the code block has been closed.
        if messages[-1]['content'].count('```') >= 2:
            break
    return extract_final_code(messages[-1]['content'])
def extract_final_code(llm_output):
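    """Return the contents of the first fenced code block in llm_output."""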
    final_code = []
    in_code_block = False
    for line in llm_output.split('\n'):
        stripped = line.strip()
        # A bare ``` while inside the block closes it; check this before the
        # opening-fence test so the closing fence is not mistaken for a new opening.
        if in_code_block and stripped == '```':
            break
        if stripped.startswith('```'):
            in_code_block = True
        elif in_code_block:
            final_code.append(line)
    return '\n'.join(final_code)
def display_file(root_path, path):
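    """Render one project file (path header plus contents) for inclusion in the prompt."""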
    abs_path = safe_child(root_path, path)
    content = f"###\npath: {path}\n###\n{abs_path.read_text()}\n###\n"
    return content
def create_prompt(root_path: pathlib.Path,
                  files,
                  model,
                  system_prompt,
                  task_prompt,
                  env: GenerationEnv):
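    """Assemble the first-stage prompt: general instructions, build status, project files and the task."""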
prompt = f"""I'm working on a software project and I'll ask you to make some changes.
{system_prompt}
General instructions:
- do not catch exceptions, unless there is a particular need
- do not add comments, unless there would be particularly helpful
"""
if env.build_successful is False:
prompt += f"\nNote: The build has failed during the last attempt.\nBuild output:\n{env.build_output}\n"
elif env.build_successful is True:
prompt += "\nNote: The build was successful during the last attempt.\n"
prompt += '\n\nProject files will now follow.\n'
assert files, 'empty files!'
for cf in files:
prompt += display_file(root_path, cf)
prompt += f'''
Here is your task:
{task_prompt}
Now make a plan of the changes.
'''
# Show me modified parts of the code, but do not output any unmodified sections of code.
return prompt
def apply_changes(env: GenerationEnv, changed_files: dict[str, str]) -> str:
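    """Snapshot the base revision as a new btrfs subvolume, write the changed files
    into it, mark it read-only and return the new revision id."""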
    if not changed_files:
        return env.base_revision
    new_revision = shortuuid.uuid()
    new_actual_root_path = env.config.state_root / new_revision
    subprocess.check_call(['btrfs', 'subvolume', 'snapshot',
                           env.base_root_path, new_actual_root_path])
    for name, data in changed_files.items():
        file_path = safe_child(new_actual_root_path, name)
        file_path.parent.mkdir(exist_ok=True, parents=True)
        file_path.write_text(data)
    subprocess.check_call(['btrfs', 'property', 'set', new_actual_root_path, 'ro', 'true'])
    return new_revision
def report_result(db, env: GenerationEnv,
                  messages: list[mllm_client.Message],
                  new_revision: str, changed_files: list[str], meta):
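    """Persist a GenerationOutput row describing this attempt and return it."""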
    out = GenerationOutput(
        task_id=env.task_id,
        generation_set_id=env.generation_set_id,
        rebased_from_generation_output=None,
        messages=messages,
        meta=meta,
        base_revision=env.base_revision,
        tip_revision=new_revision,
        changed_files=changed_files,
        created=time.time(),
    )
    db.table(GenerationOutput).insert(out)
    db.commit()
    return out
def compute_changed_files(messages: list[mllm_client.Message], files: dict[str, str], diff_model) -> dict[str, str]:
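    """Ask the model which files it wants to change, then regenerate each one in parallel:
    a full rewrite by default, or a unified diff applied with `patch` when diff_model
    ends with ':diff'. Returns a mapping of filename to new content."""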
    if len(files) == 1:
        which_files = set(files)
    else:
        which_files_resp = mllm_client.llm_one_function_call(
            messages + [{'role': 'user',
                         'content': 'Please now call the change_files function with the filenames that you want to change, based on your previous message'}],
            model='gpt-4o-mini',
            functions=[
                mllm_client.FunctionDefinition(
                    name='change_files',
                    description='Indicate all files that need to be modified.',
                    parameters={
                        "type": "object",
                        "properties": {"files": {
                            "type": "array",
                            "items": {"type": "string", "description": "Filename to be modified"}
                        }}},
                )
            ],
        )
        logging.info('which files %r', which_files_resp)
        which_files = set(which_files_resp.arguments['files'])
    def apply_unified_diff(original_content, diff_text):
        # Write the original content and the diff to temp files, apply the diff with the
        # external `patch` tool, and return the patched text.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as orig_file, \
                tempfile.NamedTemporaryFile(mode='w', delete=False) as diff_file:
            try:
                orig_file.write(original_content)
                orig_file.flush()
                diff_file.write(diff_text)
                diff_file.flush()
                result = subprocess.run(['patch', '--force', orig_file.name, diff_file.name],
                                        capture_output=True, text=True)
                if result.returncode != 0:
                    raise RuntimeError(f"Patch failed: {result.stderr}")
                with open(orig_file.name, 'r') as f:
                    return f.read()
            finally:
                os.unlink(orig_file.name)
                os.unlink(diff_file.name)
    def diff_for_file(file):
        prediction = files.get(file)
        if diff_model.endswith(':diff'):
            # Diff mode: request a unified diff and apply it with `patch`.
            base_model = diff_model.rsplit(':', 1)[0]
            patch_text = llm_code_block_with_retry(
                messages=messages
                + [{'role': 'user',
                    'content': f'''Please output a unified diff (format: unified diff starting with --- and +++ headers)
that transforms file "{file}" into your new version. Only output the diff itself in a markdown code block, no explanations.
Original content:
{prediction}'''}],
                model=base_model,
            )
            # print(patch_text)
            patched = apply_unified_diff(prediction, patch_text)
            return file, patched
        # Full-rewrite mode: ask for the entire modified file verbatim.
        code = llm_code_block_with_retry(
            messages=messages
            + [{'role': 'user',
                'content': f'''
Please now output the entire modified file "{file}" verbatim, in a single markdown code block.
DO NOT ABBREVIATE ANY PART OF THE FILE. Only make changes that you were asked to do.
DO NOT SAY "rest of the code remains the same", output the entire file "{file}" with modifications, not just parts of it.
Don't worry that your output is very long.
DO NOT REMOVE EXISTING COMMENTS FROM CODE. But also don't add meaningless comments like "# Add this line" to the code, only
add ones that you'd like to have in the final codebase.
ONLY APPLY changes that are meant for file "{file}" - ignore planned changes to other files.
IMPORTANT: DO NOT REMOVE EXISTING COMMENTS FROM CODE.
Original content:
{prediction}
'''
                }],
            model=diff_model, prediction=prediction)
        return file, code
    new_files = list(thread_pool().map(diff_for_file, which_files))
    return dict(new_files)
def run(db,
        env: GenerationEnv,
        model, diff_model, second_stage_model, max_phases: int):
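    """First stage: plan with the main model, materialise and build the changed files,
    record the result, then fan out the configured second-stage models in parallel."""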
    initial_prompt = create_prompt(env.base_root_path, env.files, model, env.config.system_prompt, env.task_prompt, env)
    messages: list[mllm_client.Message] = [{'role': 'user', 'content': initial_prompt}]
    logging.info('initial reasoning')
    messages += [{'role': 'assistant', 'content':
                  mllm_client.llm_output(messages, model=model)}]
    logging.info('computing changes')
    files_with_content = { name: safe_child(env.config.state_root / env.base_revision, name).read_text()
                           for name in env.files }
    changed_files = compute_changed_files(messages, files_with_content, diff_model)
    new_revision = apply_changes(env, changed_files)
    out = report_result(db, env, messages,
                        new_revision=new_revision,
                        changed_files=list(changed_files.keys()),
                        meta=dict(model=model, diff_model=diff_model, stage='first'))
    build_result, revision_after_build = mllm_sandbox.run_command_in_env(env, new_revision, ["bash", "-c", env.config.build_script], timeout=60)
    out.build_successful = build_result.returncode == 0
    out.build_output = (build_result.stdout + build_result.stderr).decode('utf8', 'replace')
    db.table(GenerationOutput).set(out)
    db.commit()
    # An empty --second-stage-model would otherwise yield [''] and trigger a bogus second stage.
    second_stage_models = [x.strip() for x in second_stage_model.split(',') if x.strip()]
    if second_stage_models:
        futures = [thread_pool().submit(run_all_second_stages, db, env, out, m, diff_model, max_phases)
                   for m in second_stage_models]
        for f in futures:
            f.result()
def run_all_second_stages(db, env, out, m, diff_model, max_phases):
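    """Run up to max_phases second-stage passes with one model, stopping once a build succeeds."""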
    for phase_number in range(max_phases):
        out = run_second_stage(db, env, out, m, diff_model, phase_number)
        if out.build_successful:
            break
def run_second_stage(db, env, base_out, second_stage_model, diff_model, phase_number):
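    """Show the model its own diff and build result, ask for fixes or improvements,
    apply them as a new revision and record the rebuilt result."""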
    files_with_content = {n: safe_child(env.config.state_root / base_out.tip_revision, n).read_text()
                          for n in env.files}
    diffs = mllm_state.generate_diff_for_output(env.config, base_out, html=False)
    diff_report = 'Here are the changes I made:\n\n'
    for f, d in diffs.items():
        diff_report += f'File: {f}\n{d}\n\n'
    diff_report += f'\nBuild {"succeeded" if base_out.build_successful else "failed"}.\n'
    if not base_out.build_successful:
        diff_report += f'Build output:\n{base_out.build_output}\nPlease fix the build errors.\n'
    else:
        diff_report += '\nIs there anything you want to improve? Find any mistakes in the diff and give concrete code fragments that you want to change. If the implementation is still not complete, it\'s time to complete it now.\n'
    diff_report += '\n\nAfter you are done, restate all changes that need to happen, including all changes from the previous stage.'
    messages = base_out.messages[:]
    messages.append({'role': 'user', 'content': diff_report})
    response = mllm_client.llm_output(messages, model=second_stage_model)
    messages.append({'role': 'assistant', 'content': response})
    changed_files = compute_changed_files(messages, files_with_content, diff_model)
    new_revision = apply_changes(env, changed_files)
    out2 = report_result(
        db, env, messages, new_revision, list(changed_files.keys()),
        dict(model=base_out.meta['model'], second_stage_model=second_stage_model, diff_model=diff_model, stage='second',
             phase_number=phase_number)
    )
    build_result, _ = mllm_sandbox.run_command_in_env(env, new_revision, ["bash", "-c", env.config.build_script], timeout=60)
    out2.build_successful = build_result.returncode == 0
    out2.build_output = (build_result.stdout + build_result.stderr).decode('utf8', 'replace')
    db.table(GenerationOutput).set(out2)
    db.commit()
    return out2
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description='Process some files with LLM assistance.')
    parser.add_argument('--state-root', type=pathlib.Path, required=True)
    parser.add_argument('--id', type=str, required=True)
    parser.add_argument('--max-phases', type=int, default=1, help='Max number of second-stage attempts')
    parser.add_argument('--model', type=str, default='claude-sonnet', help='The model to use for planning.')
    parser.add_argument('--diff-model', type=str, default='gpt-4o-mini', help='The model to use for generating diffs.')
    parser.add_argument('--second-stage-model', type=str, default='gpt-4o-mini', help='The model to use for the second stage of planning.')
    args = parser.parse_args()
    db = reactivedb.Db(args.state_root / "db.sqlite3")
    generation_env_table = db.table(GenerationEnv)
    env = generation_env_table.get(generation_set_id=args.id)
    run(
        db,
        env=env,
        model=args.model,
        diff_model=args.diff_model,
        second_stage_model=args.second_stage_model,
        max_phases=args.max_phases,
    )