Skip to content

Commit b178f9d

Browse files
authored
Tortoise (#19)
1 parent 7002321 commit b178f9d

File tree

9 files changed

+833
-10
lines changed

9 files changed

+833
-10
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ cover/
8989

9090
# Jupyter Notebook
9191
.ipynb_checkpoints
92-
*.ipynb
92+
# *.ipynb
9393

9494
# IPython
9595
profile_default/

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
default_language_version:
33
python: python3
44
repos:
5+
- repo: https://github.com/psf/black-pre-commit-mirror
6+
rev: 25.1.0
7+
hooks:
8+
- id: black-jupyter
59
- repo: https://github.com/pre-commit/pre-commit-hooks
610
rev: v5.0.0
711
hooks:
@@ -92,3 +96,7 @@ repos:
9296
- types-tqdm
9397
- typing-extensions
9498
- wandb
99+
- repo: https://github.com/Yelp/detect-secrets
100+
rev: v1.5.0 # pinned; bump manually when a newer release is needed
101+
hooks:
102+
- id: detect-secrets

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies = [
1717
"google-auth==2.38.0",
1818
"google-cloud-storage==3.0.0",
1919
"google-cloud-secret-manager==2.23.0",
20-
"futurehouse-client==0.3.18.dev25",
20+
"futurehouse-client==0.3.18.dev80",
2121
"jupyter==1.1.1",
2222
"nbconvert==7.16.6",
2323
"notebook==7.3.2",

src/fhda/tortoise.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import os
2+
import uuid
3+
import asyncio
4+
import copy
5+
from typing import Any, Callable, Optional
6+
from os import PathLike
7+
import time
8+
import json
9+
from pydantic import BaseModel, Field
10+
from . import prompts
11+
12+
from futurehouse_client import FutureHouseClient
13+
from futurehouse_client.models import TaskRequest, RuntimeConfig
14+
from futurehouse_client.models.app import AuthType
15+
16+
17+
class StepConfig(BaseModel):
    """Configuration for a step in the pipeline."""

    # Execution-environment language (e.g. "PYTHON"); Step.cot_prompting
    # switches to R-specific guidelines when this is "R".
    language: str = Field(
        default="PYTHON", description="Language for execution environment"
    )
    # Upper bound on agent steps; forwarded to RuntimeConfig.max_steps
    # in Tortoise.run_pipeline.
    max_steps: int = Field(
        default=30, description="Maximum number of steps for the agent"
    )
    # Per-step wall-clock timeout in seconds (default 15 minutes); passed to
    # arun_tasks_until_done.
    timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds")
    # Forwarded as environment_config["eval"]; exact semantics are defined by
    # the execution environment — TODO confirm.
    eval: bool = Field(default=True, description="Whether to use eval mode")
28+
29+
30+
class Step(BaseModel):
    """A step in the agent execution pipeline.

    Holds the job name, prompt template and arguments, file-transfer maps,
    parallelism settings, and optional hooks (post_process, prompt_generator)
    consumed by Tortoise.run_pipeline.
    """

    name: str = Field(
        description="Name of the job to run (e.g. 'job-futurehouse-data-analysis-crow-high')"
    )
    prompt_template: str = Field(description="Prompt template to use for the step")
    cot_prompt: bool = Field(
        default=False, description="Whether to augment the query with COT prompting"
    )
    prompt_args: dict[str, Any] = Field(
        default_factory=dict,
        description="Keyword arguments to format the prompt template.",
    )
    # Mapping of local source path -> destination name; only the source path
    # is actually passed to client.upload_file in run_pipeline.
    input_files: dict[str, str] = Field(
        default_factory=dict, description="Files to upload {'source_path': 'dest_name'}"
    )
    output_files: dict[str, str] = Field(
        default_factory=dict,
        description="Files to download {'source_name': 'dest_path'}",
    )
    # 8-character hex prefix of a UUID4; used as the results key and
    # per-step output subdirectory name.
    step_id: str = Field(
        default_factory=lambda: str(uuid.uuid4())[:8],
        description="Small UID for the step",
    )
    upload_id: Optional[str] = Field(default=None, description="Upload ID for GCS")
    parallel: int = Field(default=1, description="Number of parallel tasks to run")
    config: StepConfig = Field(
        default_factory=StepConfig, description="Configuration for the step"
    )
    # Called as post_process(step_results, step_output_dir) after downloads.
    post_process: Optional[Callable[[dict[str, Any], str], None]] = Field(
        default=None, description="Function to run after step completion"
    )
    # Returns a list of (prompt_text, prompt_args) pairs; only consulted when
    # parallel > 1 (see Tortoise._create_task_requests).
    prompt_generator: Optional[Callable[[], list[tuple[str, dict[str, Any]]]]] = Field(
        default=None,
        description="Function to generate prompts and args for parallel tasks based on previous results",
    )

    def cot_prompting(self, query: str, language: str) -> str:
        """Apply chain-of-thought prompting to the query.

        Wraps *query* in the chain-of-thought preamble plus language-specific
        notebook guidelines from the prompts module.
        """
        guidelines = prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language)
        if language == "R":
            # R replaces (does not extend) the general guidelines.
            guidelines = prompts.R_SPECIFIC_GUIDELINES.format(language=language)
        # NOTE(review): no newline separates {guidelines} from the
        # "Here is the research question" preamble — confirm intended.
        return (
            f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language)}\n"
            f"{guidelines}"
            f"Here is the research question to address:\n"
            f"<query>\n"
            f"{query}\n"
            f"</query>\n"
        )

    def format_prompt(self) -> str:
        """Format the prompt template with the provided arguments.

        Applies prompt_args to prompt_template, then optionally wraps the
        result with chain-of-thought prompting when cot_prompt is set.
        """
        final_prompt = self.prompt_template.format(**self.prompt_args)
        if self.cot_prompt:
            final_prompt = self.cot_prompting(final_prompt, self.config.language)
        return final_prompt
88+
89+
90+
class Tortoise:
    """Runner for multi-step agent pipelines.

    Steps are registered with ``add_step`` and executed in order by
    ``run_pipeline`` (or the synchronous ``run`` wrapper). Per-step results
    are accumulated in ``self.results`` keyed by ``step.step_id``.
    """

    def __init__(self, api_key: str):
        """Initialize the tortoise framework with FutureHouse API key."""
        self.client = FutureHouseClient(auth_type=AuthType.API_KEY, api_key=api_key)
        self.steps: list[Step] = []
        self.results: dict[str, Any] = {}

    def add_step(self, step: Step) -> None:
        """Add a step to the end of the pipeline."""
        self.steps.append(step)

    def save_results(self, output_dir: str | PathLike = "output") -> None:
        """Save accumulated results to a timestamped JSON file in output_dir.

        Best-effort: any failure (e.g. non-JSON-serializable task responses)
        is printed rather than raised, so a failed save never aborts a run.
        """
        results_path = f"{output_dir}/results_{time.strftime('%Y%m%d_%H%M%S')}.json"
        print(f"Saving all results to {results_path}")
        try:
            os.makedirs(output_dir, exist_ok=True)
            # Shallow copy of each step's result dict; task_responses may
            # still hold client objects — json.dump then raises and the
            # error is reported below.
            serializable_results = {
                step_id: dict(step_result)
                for step_id, step_result in self.results.items()
            }
            with open(results_path, "w") as f:
                json.dump(serializable_results, f, indent=2)
            print(f"Results successfully saved to {results_path}")
        except Exception as e:
            print(f"Error saving results to {results_path}: {e}")

    def _create_task_requests(
        self, step: Step, runtime_config: RuntimeConfig
    ) -> list[TaskRequest]:
        """Create task requests with either identical or dynamic prompts.

        Args:
            step: The step configuration
            runtime_config: The runtime configuration for the task

        Returns:
            List of task requests to be executed
        """
        task_count = max(step.parallel, 1)

        if step.prompt_generator and task_count > 1:
            # Generate dynamic prompts based on previous results, limited to
            # the requested parallel count.
            task_requests = []
            for prompt_text, prompt_args in step.prompt_generator()[:task_count]:
                # Deep-copy so per-task prompt mutation never leaks into the
                # shared Step definition.
                step_copy = copy.deepcopy(step)
                step_copy.prompt_template = prompt_text
                step_copy.prompt_args = prompt_args
                task_requests.append(
                    TaskRequest(
                        name=step.name,
                        query=step_copy.format_prompt(),
                        runtime_config=runtime_config,
                    )
                )
        else:
            # Default behavior: same prompt for every task. Build distinct
            # TaskRequest objects — the previous `[request] * task_count`
            # repeated ONE shared instance, so mutating any element would
            # have affected all of them.
            query = step.format_prompt()
            task_requests = [
                TaskRequest(
                    name=step.name,
                    query=query,
                    runtime_config=runtime_config,
                )
                for _ in range(task_count)
            ]

        return task_requests

    async def run_pipeline(
        self, output_dir: str | PathLike = "output"
    ) -> dict[str, Any]:
        """Run the entire pipeline of steps.

        For each step: upload input files, launch the (possibly parallel)
        tasks, wait for completion, download output files, and run the
        optional post-processing hook. Results are saved to a JSON file at
        the end and returned.
        """
        os.makedirs(output_dir, exist_ok=True)

        for i, step in enumerate(self.steps):
            print(f"Running step {i + 1}/{len(self.steps)}: {step.name}")
            if not step.upload_id:
                step.upload_id = f"{step.name}_{step.step_id}"

            for source_path, dest_name in step.input_files.items():
                print(f"Uploading file {source_path} as {dest_name}")
                self.client.upload_file(
                    step.name, file_path=source_path, upload_id=step.upload_id
                )

            # step.config has a default factory, so this branch is always
            # taken in practice; the None fallback is kept defensively.
            if step.config:
                runtime_config = RuntimeConfig(
                    max_steps=step.config.max_steps,
                    upload_id=step.upload_id,
                    environment_config={
                        "eval": step.config.eval,
                        "language": step.config.language,
                    },
                )
            else:
                runtime_config = None

            task_requests = self._create_task_requests(step, runtime_config)

            print(
                f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}"
            )
            task_responses = await self.client.arun_tasks_until_done(
                task_requests,
                progress_bar=True,
                verbose=True,
                timeout=step.config.timeout,
            )

            task_ids = [str(task.task_id) for task in task_responses]
            # Guard the division: an empty response list previously raised
            # ZeroDivisionError; report 0.0 instead.
            success_rate = (
                sum(task.status == "success" for task in task_responses)
                / len(task_responses)
                if task_responses
                else 0.0
            )
            print(f"Task success rate: {success_rate * 100}%")

            self.results[step.step_id] = {
                "task_ids": task_ids,
                "task_responses": task_responses,
                "success_rate": success_rate,
            }

            os.makedirs(f"{output_dir}/{step.step_id}", exist_ok=True)

            for idx, task_id in enumerate(task_ids):
                for source_name, dest_path in step.output_files.items():
                    try:
                        # Add index suffix only when there are multiple tasks
                        path_suffix = f"_{idx}" if len(task_ids) > 1 else ""
                        if "." in dest_path:
                            base, ext = os.path.splitext(dest_path)
                            dest_path_with_idx = f"{base}{path_suffix}{ext}"
                        else:
                            dest_path_with_idx = f"{dest_path}{path_suffix}"

                        path = f"{output_dir}/{step.step_id}/{dest_path_with_idx}"
                        os.makedirs(
                            os.path.dirname(os.path.abspath(path)), exist_ok=True
                        )
                        print(f"Downloading file {source_name} to {path}")
                        self.client.download_file(
                            step.name,
                            trajectory_id=task_id,
                            file_path=source_name,
                            destination_path=path,
                        )
                    except Exception as e:
                        # Best-effort download: keep going for remaining files.
                        print(
                            f"Error downloading {source_name} from task {task_id}: {e}"
                        )

            if step.post_process:
                print(f"Running post-processing for step {step.step_id}")
                step.post_process(
                    self.results[step.step_id], f"{output_dir}/{step.step_id}"
                )

            print(f"Completed step {i + 1}/{len(self.steps)}")

        self.save_results(output_dir)
        return self.results

    def run(self, output_dir: str | PathLike = "output") -> dict[str, Any]:
        """Synchronous version of run_pipeline."""
        return asyncio.run(self.run_pipeline(output_dir))

src/scripts/deploy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from futurehouse_client.models.app import TaskQueuesConfig
1414

15-
HIGH = False
15+
HIGH = True
1616
ENVIRONMENT = "DEV"
1717

1818
ENV_VARS = {
@@ -32,9 +32,9 @@
3232
FramePath(path="state.nb_state_html", type="notebook"),
3333
]
3434

35-
MODEL = "claude-3-7-sonnet-latest"
36-
TEMPERATURE = 1
37-
NUM_RETRIES = 3
35+
# MODEL = "claude-3-7-sonnet-latest"
36+
# TEMPERATURE = 1
37+
# NUM_RETRIES = 3
3838

3939
# agent = AgentConfig(
4040
# agent_type="ReActAgent",

0 commit comments

Comments
 (0)