Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions agents/impl_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class GenerationAttempt:
@dataclasses.dataclass(frozen=True)
class ImplGenerationRequest:
interface_str: str
test_str: str
prior_attempts: list[GenerationAttempt] = dataclasses.field(default_factory=list)


Expand All @@ -30,7 +29,7 @@ def __init__(self, model_str: str = ''):
'The implementation should be fast, memory-efficient, and as simple as possible while meeting all requirements.')
)

def _make_initial_prompt(self, python_interface: str, test_str: str) -> str:
def _make_initial_prompt(self, python_interface: str, test_code: str) -> str:
example_impl = textwrap.dedent('''
from my_interface import MyInterface

Expand All @@ -50,13 +49,13 @@ def foo(self) -> str:
'The code you will generate is *not* an abstract class, and does *not* have any `@abstractmethod` annotations. '
'The interface itself already exists in the same directory, so do not add it here. '
'The test suite that should pass looks like this:\n\n'
f'{utils.wrap_code_in_markdown(test_str)}'
f'{utils.wrap_code_in_markdown(test_code)}'
'An example implementation might look something like this:\n\n'
f'{utils.wrap_code_in_markdown(example_impl)}'
)

def _make_improvement_prompt(
self, python_interface: str, test_str: str, prior_attempts: list[GenerationAttempt] = []
self, python_interface: str, test_code: str, prior_attempts: list[GenerationAttempt] = []
) -> str:
# This variable is currently unused, but kept for possible future use
_ = textwrap.dedent('''
Expand All @@ -79,7 +78,7 @@ def foo(self) -> str:
'Your instructions were to make sure the name of the class ends with "Impl", and it inherits from the interface. '
'You can assume the interface exists the same directory as the implementation being generated. '
'The test suite that was run looks like this:\n\n'
f'{utils.wrap_code_in_markdown(test_str)}'
f'{utils.wrap_code_in_markdown(test_code)}'
'When the tests were run, the following output indicates some problems:'
f'```\n{prior_attempts[-1].errors}\n```\n\n'
'Please generate a new implementation according to the same instructions, '
Expand Down
6 changes: 4 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ def main(example_name: str):
return

test_str = test_gen.str_to_file(request, test_path)
impl_str = impl_gen.str_to_file(impl_generator.ImplGenerationRequest(example.code, test_str), impl_path)
with open(test_path, 'r') as test_file:
test_code = test_file.read()
impl_str = impl_gen.str_to_file(impl_generator.ImplGenerationRequest(example.code, prior_attempts=[]), impl_path)
tests_pass, test_output = run_tests(project_dir)

if not tests_pass:
Expand All @@ -194,7 +196,7 @@ def main(example_name: str):
while not tests_pass and num_impl_rounds > 0:
print('Tests did not pass, trying another round of impl generation.')
impl_attempts.append(impl_generator.GenerationAttempt(impl_str, test_output))
impl_request = impl_generator.ImplGenerationRequest(example.code, test_str, impl_attempts)
impl_request = impl_generator.ImplGenerationRequest(example.code, impl_attempts)
impl_str = impl_gen.str_to_file(impl_request, impl_path)
tests_pass, test_output = run_tests(project_dir)
num_impl_rounds -= 1
Expand Down
12 changes: 6 additions & 6 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,27 @@ def test_iteration_loop(self, iface_path: str) -> None:

test_path = iface_path.replace('.py', '_test.py')
request = test_generator.TestGenerationRequest(interface_str=iface_str)
test_str = self._test_gen.str_to_file(request, test_path)
test_code = self._test_gen.str_to_file(request, test_path)
if compilation_error := try_compile_file(test_path):
attempt = test_generator.GenerationAttempt(code=test_str, errors=compilation_error)
attempt = test_generator.GenerationAttempt(code=test_code, errors=compilation_error)
request = test_generator.TestGenerationRequest(interface_str=iface_str, prior_attempts=[attempt])
test_str = self._test_gen.str_to_file(request, test_path)
test_code = self._test_gen.str_to_file(request, test_path)

def impl_iteration_loop(self, iface_path: str, test_path: str) -> None:
"""Iterates on impl creation, returning the finished file."""
with open(iface_path, 'r') as iface_file:
iface_str = iface_file.read()
with open(test_path, 'r') as test_file:
test_str = test_file.read()
test_code = test_file.read()

impl_path = iface_path.replace('.py', '_impl.py')
request = impl_generator.ImplGenerationRequest(iface_str, test_str)
request = impl_generator.ImplGenerationRequest(iface_str, prior_attempts=[])
impl_str = self._impl_gen.str_to_file(request, impl_path)
tests_pass, test_output = run_tests(self._working_dir)
attempts = []
while not tests_pass:
attempts.append(impl_generator.GenerationAttempt(code=impl_str, errors=test_output))
request = impl_generator.ImplGenerationRequest(iface_str, test_str, attempts)
request = impl_generator.ImplGenerationRequest(iface_str, attempts)
impl_str = self._impl_gen.str_to_file(request, impl_path)
tests_pass, test_output = run_tests(self._working_dir)

Expand Down