diff --git a/agents/impl_generator.py b/agents/impl_generator.py index fa293d9..33ce9f5 100644 --- a/agents/impl_generator.py +++ b/agents/impl_generator.py @@ -15,7 +15,6 @@ class GenerationAttempt: @dataclasses.dataclass(frozen=True) class ImplGenerationRequest: interface_str: str - test_str: str prior_attempts: list[GenerationAttempt] = dataclasses.field(default_factory=list) @@ -30,7 +29,7 @@ def __init__(self, model_str: str = ''): 'The implementation should be fast, memory-efficient, and as simple as possible while meeting all requirements.') ) - def _make_initial_prompt(self, python_interface: str, test_str: str) -> str: + def _make_initial_prompt(self, python_interface: str, test_code: str) -> str: example_impl = textwrap.dedent(''' from my_interface import MyInterface @@ -50,13 +49,13 @@ def foo(self) -> str: 'The code you will generate is *not* an abstract class, and does *not* have any `@abstractmethod` annotations. ' 'The interface itself already exists in the same directory, so do not add it here. ' 'The test suite that should pass looks like this:\n\n' - f'{utils.wrap_code_in_markdown(test_str)}' + f'{utils.wrap_code_in_markdown(test_code)}' 'An example implementation might look something like this:\n\n' f'{utils.wrap_code_in_markdown(example_impl)}' ) def _make_improvement_prompt( - self, python_interface: str, test_str: str, prior_attempts: list[GenerationAttempt] = [] + self, python_interface: str, test_code: str, prior_attempts: list[GenerationAttempt] = [] ) -> str: # This variable is currently unused, but kept for possible future use _ = textwrap.dedent(''' @@ -79,7 +78,7 @@ def foo(self) -> str: 'Your instructions were to make sure the name of the class ends with "Impl", and it inherits from the interface. ' 'You can assume the interface exists the same directory as the implementation being generated. ' 'The test suite that was run looks like this:\n\n' - f'{utils.wrap_code_in_markdown(test_str)}' + f'{utils.wrap_code_in_markdown(test_code)}' 'When the tests were run, the following output indicates some problems:' f'```\n{prior_attempts[-1].errors}\n```\n\n' 'Please generate a new implementation according to the same instructions, ' diff --git a/demo.py b/demo.py index 977a3d1..617e2d6 100644 --- a/demo.py +++ b/demo.py @@ -177,7 +177,9 @@ def main(example_name: str): return test_str = test_gen.str_to_file(request, test_path) - impl_str = impl_gen.str_to_file(impl_generator.ImplGenerationRequest(example.code, test_str), impl_path) + with open(test_path, 'r') as test_file: + test_code = test_file.read() + impl_str = impl_gen.str_to_file(impl_generator.ImplGenerationRequest(example.code, prior_attempts=[]), impl_path) tests_pass, test_output = run_tests(project_dir) if not tests_pass: @@ -194,7 +196,7 @@ def main(example_name: str): while not tests_pass and num_impl_rounds > 0: print('Tests did not pass, trying another round of impl generation.') impl_attempts.append(impl_generator.GenerationAttempt(impl_str, test_output)) - impl_request = impl_generator.ImplGenerationRequest(example.code, test_str, impl_attempts) + impl_request = impl_generator.ImplGenerationRequest(example.code, impl_attempts) impl_str = impl_gen.str_to_file(impl_request, impl_path) tests_pass, test_output = run_tests(project_dir) num_impl_rounds -= 1 diff --git a/start.py b/start.py index 01bf85a..2e7de09 100644 --- a/start.py +++ b/start.py @@ -83,27 +83,27 @@ def test_iteration_loop(self, iface_path: str) -> None: test_path = iface_path.replace('.py', '_test.py') request = test_generator.TestGenerationRequest(interface_str=iface_str) - test_str = self._test_gen.str_to_file(request, test_path) + test_code = self._test_gen.str_to_file(request, test_path) if compilation_error := try_compile_file(test_path): - attempt = test_generator.GenerationAttempt(code=test_str, errors=compilation_error) + attempt = test_generator.GenerationAttempt(code=test_code, errors=compilation_error) request = test_generator.TestGenerationRequest(interface_str=iface_str, prior_attempts=[attempt]) - test_str = self._test_gen.str_to_file(request, test_path) + test_code = self._test_gen.str_to_file(request, test_path) def impl_iteration_loop(self, iface_path: str, test_path: str) -> None: """Iterates on impl creation, returning the finished file.""" with open(iface_path, 'r') as iface_file: iface_str = iface_file.read() with open(test_path, 'r') as test_file: - test_str = test_file.read() + test_code = test_file.read() impl_path = iface_path.replace('.py', '_impl.py') - request = impl_generator.ImplGenerationRequest(iface_str, test_str) + request = impl_generator.ImplGenerationRequest(iface_str, prior_attempts=[]) impl_str = self._impl_gen.str_to_file(request, impl_path) tests_pass, test_output = run_tests(self._working_dir) attempts = [] while not tests_pass: attempts.append(impl_generator.GenerationAttempt(code=impl_str, errors=test_output)) - request = impl_generator.ImplGenerationRequest(iface_str, test_str, attempts) + request = impl_generator.ImplGenerationRequest(iface_str, attempts) impl_str = self._impl_gen.str_to_file(request, impl_path) tests_pass, test_output = run_tests(self._working_dir)