import asyncio
import copy
import json
import os
import time
import uuid
from os import PathLike
from typing import Any, Callable, Optional

from pydantic import BaseModel, Field

from futurehouse_client import FutureHouseClient
from futurehouse_client.models import TaskRequest, RuntimeConfig
from futurehouse_client.models.app import AuthType

from . import prompts


class StepConfig(BaseModel):
    """Configuration for a step in the pipeline."""

    language: str = Field(
        default="PYTHON", description="Language for the execution environment"
    )
    max_steps: int = Field(
        default=30, description="Maximum number of steps for the agent"
    )
    timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds")
    eval: bool = Field(default=True, description="Whether to use eval mode")
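
# Illustrative override (a sketch, not prescribed usage): a longer-running R
# analysis could replace the defaults above with, e.g.:
#     StepConfig(language="R", max_steps=50, timeout=30 * 60)
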
class Step(BaseModel):
    """A step in the agent execution pipeline."""

    name: str = Field(
        description="Name of the job to run (e.g. 'job-futurehouse-data-analysis-crow-high')"
    )
    prompt_template: str = Field(description="Prompt template to use for the step")
    cot_prompt: bool = Field(
        default=False, description="Whether to augment the query with CoT prompting"
    )
    prompt_args: dict[str, Any] = Field(
        default_factory=dict,
        description="Keyword arguments to format the prompt template.",
    )
    input_files: dict[str, str] = Field(
        default_factory=dict, description="Files to upload {'source_path': 'dest_name'}"
    )
    output_files: dict[str, str] = Field(
        default_factory=dict,
        description="Files to download {'source_name': 'dest_path'}",
    )
    step_id: str = Field(
        default_factory=lambda: str(uuid.uuid4())[:8],
        description="Short UID for the step",
    )
    upload_id: Optional[str] = Field(default=None, description="Upload ID for GCS")
    parallel: int = Field(default=1, description="Number of parallel tasks to run")
    config: StepConfig = Field(
        default_factory=StepConfig, description="Configuration for the step"
    )
    post_process: Optional[Callable[[dict[str, Any], str], None]] = Field(
        default=None, description="Function to run after step completion"
    )
    prompt_generator: Optional[Callable[[], list[tuple[str, dict[str, Any]]]]] = Field(
        default=None,
        description="Function that generates prompts and args for parallel tasks based on previous results",
    )

    def cot_prompting(self, query: str, language: str) -> str:
        """Apply chain-of-thought prompting to the query."""
        guidelines = prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language)
        if language == "R":
            guidelines = prompts.R_SPECIFIC_GUIDELINES.format(language=language)
        return (
            f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language)}\n"
            f"{guidelines}"
            f"Here is the research question to address:\n"
            f"<query>\n"
            f"{query}\n"
            f"</query>\n"
        )

    def format_prompt(self) -> str:
        """Format the prompt template with the provided arguments."""
        final_prompt = self.prompt_template.format(**self.prompt_args)
        if self.cot_prompt:
            final_prompt = self.cot_prompting(final_prompt, self.config.language)
        return final_prompt
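
# Formatting sketch: with prompt_template="Summarize {topic}" and
# prompt_args={"topic": "CRISPR screens"} (both hypothetical values),
# format_prompt() returns "Summarize CRISPR screens"; with cot_prompt=True the
# result is additionally wrapped in the chain-of-thought scaffolding from
# the `prompts` module.
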
class Tortoise:
    """Runner for multi-step agent pipelines."""

    def __init__(self, api_key: str):
        """Initialize the Tortoise runner with a FutureHouse API key."""
        self.client = FutureHouseClient(auth_type=AuthType.API_KEY, api_key=api_key)
        self.steps: list[Step] = []
        self.results: dict[str, Any] = {}

    def add_step(self, step: Step) -> None:
        """Add a step to the pipeline."""
        self.steps.append(step)

    def save_results(self, output_dir: str | PathLike = "output") -> None:
        """Save the results to a timestamped JSON file."""
        results_path = f"{output_dir}/results_{time.strftime('%Y%m%d_%H%M%S')}.json"
        print(f"Saving all results to {results_path}")
        try:
            os.makedirs(output_dir, exist_ok=True)
            serializable_results = {
                step_id: dict(step_result)
                for step_id, step_result in self.results.items()
            }
            with open(results_path, "w") as f:
                # TaskResponse objects are not JSON-serializable; fall back to
                # their string representation rather than failing the dump.
                json.dump(serializable_results, f, indent=2, default=str)
            print(f"Results successfully saved to {results_path}")
        except Exception as e:
            print(f"Error saving results to {results_path}: {e}")
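
    # Saved-results shape (a sketch; entries under task_responses serialize via
    # the default=str fallback above):
    #     {"<step_id>": {"task_ids": ["..."], "task_responses": ["..."],
    #                    "success_rate": 1.0}}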

    def _create_task_requests(
        self, step: Step, runtime_config: Optional[RuntimeConfig]
    ) -> list[TaskRequest]:
        """Create task requests with either identical or dynamic prompts.

        Args:
            step: The step configuration.
            runtime_config: The runtime configuration for the task.

        Returns:
            List of task requests to be executed.
        """
        task_requests = []
        task_count = max(step.parallel, 1)

        if step.prompt_generator and task_count > 1:
            # Generate dynamic prompts based on previous results, one per task,
            # limited to the requested parallel count.
            prompt_pairs = step.prompt_generator()
            for prompt_text, prompt_args in prompt_pairs[:task_count]:
                step_copy = copy.deepcopy(step)
                step_copy.prompt_template = prompt_text
                step_copy.prompt_args = prompt_args
                query = step_copy.format_prompt()
                task_requests.append(
                    TaskRequest(
                        name=step.name,
                        query=query,
                        runtime_config=runtime_config,
                    )
                )
        else:
            # Default behavior: use the same prompt for all tasks. Build a
            # distinct TaskRequest per task; multiplying a one-element list
            # would alias the same request object task_count times.
            query = step.format_prompt()
            task_requests = [
                TaskRequest(
                    name=step.name,
                    query=query,
                    runtime_config=runtime_config,
                )
                for _ in range(task_count)
            ]

        return task_requests
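
    # prompt_generator sketch (hypothetical helper, not part of this module):
    # return one (template, args) pair per parallel task, e.g.
    #
    #     def make_prompts() -> list[tuple[str, dict[str, Any]]]:
    #         return [("Investigate gene {g}", {"g": g}) for g in ("TP53", "BRCA1")]
    #
    # wired up via Step(prompt_generator=make_prompts, parallel=2, ...).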

    async def run_pipeline(
        self, output_dir: str | PathLike = "output"
    ) -> dict[str, Any]:
        """Run the entire pipeline of steps."""
        os.makedirs(output_dir, exist_ok=True)

        for i, step in enumerate(self.steps):
            print(f"Running step {i + 1}/{len(self.steps)}: {step.name}")
            if not step.upload_id:
                step.upload_id = f"{step.name}_{step.step_id}"

            for source_path, dest_name in step.input_files.items():
                print(f"Uploading file {source_path} as {dest_name}")
                self.client.upload_file(
                    step.name, file_path=source_path, upload_id=step.upload_id
                )

            if step.config:
                runtime_config = RuntimeConfig(
                    max_steps=step.config.max_steps,
                    upload_id=step.upload_id,
                    environment_config={
                        "eval": step.config.eval,
                        "language": step.config.language,
                    },
                )
            else:
                runtime_config = None

            task_requests = self._create_task_requests(step, runtime_config)

            print(
                f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}"
            )
            task_responses = await self.client.arun_tasks_until_done(
                task_requests,
                progress_bar=True,
                verbose=True,
                timeout=step.config.timeout,
            )

            task_ids = [str(task.task_id) for task in task_responses]
            success_rate = sum(
                task.status == "success" for task in task_responses
            ) / len(task_responses)
            print(f"Task success rate: {success_rate * 100:.1f}%")

            self.results[step.step_id] = {
                "task_ids": task_ids,
                "task_responses": task_responses,
                "success_rate": success_rate,
            }

            os.makedirs(f"{output_dir}/{step.step_id}", exist_ok=True)

            for idx, task_id in enumerate(task_ids):
                for source_name, dest_path in step.output_files.items():
                    try:
                        # Add an index suffix only when there are multiple
                        # tasks, inserted before the extension if there is one.
                        path_suffix = f"_{idx}" if len(task_ids) > 1 else ""
                        if "." in dest_path:
                            base, ext = os.path.splitext(dest_path)
                            dest_path_with_idx = f"{base}{path_suffix}{ext}"
                        else:
                            dest_path_with_idx = f"{dest_path}{path_suffix}"

                        path = f"{output_dir}/{step.step_id}/{dest_path_with_idx}"
                        os.makedirs(
                            os.path.dirname(os.path.abspath(path)), exist_ok=True
                        )
                        print(f"Downloading file {source_name} to {path}")
                        self.client.download_file(
                            step.name,
                            trajectory_id=task_id,
                            file_path=source_name,
                            destination_path=path,
                        )
                    except Exception as e:
                        print(
                            f"Error downloading {source_name} from task {task_id}: {e}"
                        )

            if step.post_process:
                print(f"Running post-processing for step {step.step_id}")
                step.post_process(
                    self.results[step.step_id], f"{output_dir}/{step.step_id}"
                )

            print(f"Completed step {i + 1}/{len(self.steps)}")

        self.save_results(output_dir)
        return self.results

    def run(self, output_dir: str | PathLike = "output") -> dict[str, Any]:
        """Synchronous version of run_pipeline."""
        return asyncio.run(self.run_pipeline(output_dir))
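
# Usage sketch (illustrative only; the job name mirrors the example in the
# Step.name description, while the env var name and file paths are hypothetical):
#
#     tortoise = Tortoise(api_key=os.environ["FUTUREHOUSE_API_KEY"])
#     tortoise.add_step(
#         Step(
#             name="job-futurehouse-data-analysis-crow-high",
#             prompt_template="Analyze {dataset} and report summary statistics.",
#             prompt_args={"dataset": "expression.csv"},
#             input_files={"data/expression.csv": "expression.csv"},
#             output_files={"analysis.ipynb": "analysis.ipynb"},
#         )
#     )
#     results = tortoise.run(output_dir="output")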