
Commit 4ebcbf3

Tortoise and notebook updates (#24)
1 parent 095df6e

File tree

6 files changed: +151 -121 lines changed


pyproject.toml

Lines changed: 4 additions & 2 deletions
@@ -8,6 +8,7 @@ authors = [
 ]
 dependencies = [
     "aiodocker==0.24.0",
+    "anthropic==0.52.2", # this is necessary for tortoise, remove in favor of LMI when it works with search
     "fhaviary[server]==0.19.0",
     "ldp==0.26.0",
     "pandas==2.2.3",
@@ -17,11 +18,12 @@ dependencies = [
     "google-auth==2.38.0",
     "google-cloud-storage==3.0.0",
     "google-cloud-secret-manager==2.23.0",
-    "futurehouse-client==0.3.18",
+    "futurehouse-client==0.3.19",
     "jupyter==1.1.1",
     "nbconvert==7.16.6",
     "notebook==7.3.2",
-    "nbformat==5.10.4"
+    "nbformat==5.10.4",
+    "seaborn==0.13.2"
 ]
 description = "Data analysis crow"
 name = "fhda"

src/fhda/tortoise.py

Lines changed: 122 additions & 105 deletions
@@ -13,28 +13,17 @@
     wait_exponential,
     retry_if_exception_type,
 )
-from . import prompts
+from . import config as cfg

 from futurehouse_client import FutureHouseClient
 from futurehouse_client.models import TaskRequest, RuntimeConfig
-from futurehouse_client.models.app import AuthType
-from futurehouse_client.clients.rest_client import TaskFetchError
+from futurehouse_client.models.app import AuthType, Stage
+import anthropic
+import logging
+import traceback

-
-class StepConfig(BaseModel):
-    """Agent runtime configuration."""
-
-    language: str = Field(
-        default="PYTHON", description="Language for execution environment"
-    )
-    max_steps: int = Field(
-        default=30, description="Maximum number of steps for the agent"
-    )
-    timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds")
-    eval: bool = Field(
-        default=True,
-        description="For Finch, this indicates whether this is an API call or UI call. Setting it to True removes the automatic CoT additions.",
-    )
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)


 class Step(BaseModel):
@@ -43,10 +32,16 @@ class Step(BaseModel):
     name: str = Field(
         description="Name of the job to run (e.g. 'job-futurehouse-data-analysis-crow-high')"
     )
-    prompt_template: str = Field(description="Prompt template to use for the step")
-    cot_prompt: bool = Field(
-        default=False, description="Whether to augment the query with COT prompting"
+    llm_call: bool = Field(
+        default=False, description="Whether to call the LLM for the step"
+    )
+    include_search_tool: bool = Field(
+        default=False, description="Whether to include the search tool in the LLM call"
     )
+    model_name: str = Field(
+        default=cfg.DEFAULT_MODEL, description="Name of the model to use for the step"
+    )
+    prompt_template: str = Field(description="Prompt template to use for the step")
     prompt_args: dict[str, Any] = Field(
         default_factory=dict,
         description="Keyword arguments to format the prompt template.",
@@ -59,13 +54,14 @@ class Step(BaseModel):
         description="Files to download {'source_name': 'dest_path'}",
     )
     step_id: str = Field(
-        default_factory=lambda: str(uuid.uuid4())[:8],
+        default_factory=lambda: str(uuid.uuid4()),
         description="Small UID for the step",
     )
-    upload_id: Optional[str] = Field(default=None, description="Upload ID for GCS")
-    parallel: int = Field(default=1, description="Number of parallel tasks to run")
-    config: StepConfig = Field(
-        default_factory=StepConfig, description="Configuration for the step"
+    n_replicate_tasks: int = Field(
+        default=1, description="Number of parallel tasks to run"
+    )
+    runtime_config: RuntimeConfig = Field(
+        default_factory=RuntimeConfig, description="Configuration for the step"
     )
     post_process: Optional[Callable[[dict[str, Any], str], None]] = Field(
         default=None, description="Function to run after step completion"
@@ -74,36 +70,24 @@ class Step(BaseModel):
         default=None,
         description="Function to generate prompts and args for parallel tasks based on previous results",
     )
-
-    def cot_prompting(self, query: str, language: str) -> str:
-        """Apply chain-of-thought prompting to the query."""
-        guidelines = prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language)
-        if language == "R":
-            guidelines = prompts.R_SPECIFIC_GUIDELINES.format(language=language)
-        return (
-            f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language)}\n"
-            f"{guidelines}"
-            f"Here is the research question to address:\n"
-            f"<query>\n"
-            f"{query}\n"
-            f"</query>\n"
-        )
+    timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds")

     def format_prompt(self) -> str:
         """Format the prompt template with the provided arguments."""
         final_prompt = self.prompt_template.format(**self.prompt_args)
-        if self.cot_prompt:
-            final_prompt = self.cot_prompting(final_prompt, self.config.language)
         return final_prompt


 class Tortoise:
     """Runner for multi-step agent pipelines."""

-    def __init__(self, api_key: str):
+    def __init__(self, api_key: str, environment: str = "PROD"):
         """Initialize the tortoise framework with FutureHouse API key."""
         self.client = FutureHouseClient(
-            auth_type=AuthType.API_KEY, api_key=api_key, verbose_logging=True
+            auth_type=AuthType.API_KEY,
+            api_key=api_key,
+            verbose_logging=True,
+            stage=getattr(Stage, environment.upper(), Stage.PROD),
         )
         self.steps: list[Step] = []
         self.results: dict[str, Any] = {}
@@ -115,18 +99,18 @@ def add_step(self, step: Step) -> None:
     def save_results(self, output_dir: str | PathLike = "output") -> None:
         """Save the results to a JSON file."""
         results_path = f"{output_dir}/results_{time.strftime('%Y%m%d_%H%M%S')}.json"
-        print(f"Saving all results to {results_path}")
+        logger.info(f"Saving all results to {results_path}")
         try:
             os.makedirs(output_dir, exist_ok=True)
             serializable_results = {}
             for step_id, step_result in self.results.items():
                 serializable_results[step_id] = dict(step_result)

             with open(results_path, "w") as f:
-                json.dump(serializable_results, f, indent=2)
-            print(f"Results successfully saved to {results_path}")
+                json.dump(serializable_results, f, indent=2, default=str)
+            logger.info(f"Results successfully saved to {results_path}")
         except Exception as e:
-            print(f"Error saving results to {results_path}: {e}")
+            logger.error(f"Error saving results to {results_path}: {e}")

     @retry(
         stop=stop_after_attempt(3),
@@ -168,7 +152,21 @@ def _create_task_requests(
             List of task requests to be executed
         """
         task_requests = []
-        task_count = max(step.parallel, 1)
+        task_count = max(step.n_replicate_tasks, 1)
+
+        if step.model_name:
+            agent_config = cfg.get_custom_agent_config(step.model_name)
+            runtime_config.agent = agent_config
+
+        if step.runtime_config.continued_job_id:
+            task_ids = self.results[str(step.runtime_config.continued_job_id)][
+                "task_ids"
+            ]
+            if len(task_ids) > 1:
+                logger.warning(
+                    f"Continued job {step.runtime_config.continued_job_id} has multiple task ids, using the first one"
+                )
+            runtime_config.continued_job_id = str(task_ids[0])

         if step.prompt_generator and task_count > 1:
             # Generate dynamic prompts based on previous results
@@ -201,11 +199,34 @@

         return task_requests

-    @retry(
-        stop=stop_after_attempt(5),
-        wait=wait_exponential(multiplier=1, min=2, max=30),
-        retry=retry_if_exception_type((Exception, TaskFetchError)),
-    )
+    async def call_llm(self, step: Step) -> list:
+        """Call the LLM for the step."""
+        anthropic_client = anthropic.Anthropic()
+        # TODO: This is a hack to get the model name without the provider prefix
+        model_name = step.model_name.replace("anthropic/", "")
+        if step.include_search_tool:
+            tools = [
+                {
+                    "type": "web_search_20250305",
+                    "name": "web_search",
+                }
+            ]
+        else:
+            tools = []
+        response = anthropic_client.messages.create(
+            model=model_name,
+            messages=[
+                {
+                    "role": "user",
+                    "content": step.prompt_template,
+                }
+            ],
+            tools=tools,
+            max_tokens=8192,
+        )
+        result = "\n".join([r.text for r in response.content if hasattr(r, "text")])
+        return [result]
+
     async def _run_tasks_with_retry(
         self, task_requests, progress_bar, verbose, timeout
     ):
@@ -225,64 +246,60 @@ async def run_pipeline(
         os.makedirs(output_dir, exist_ok=True)

         for i, step in enumerate(self.steps):
-            print(f"Running step {i + 1}/{len(self.steps)}: {step.name}")
-            if not step.upload_id:
-                step.upload_id = f"{step.name}_{step.step_id}"
+            logger.info(f"Running step {i + 1}/{len(self.steps)}: {step.name}")
+            if not step.runtime_config.upload_id:
+                step.runtime_config.upload_id = step.step_id

             for source_path, dest_name in step.input_files.items():
-                print(f"Uploading file {source_path} as {dest_name}")
+                logger.info(f"Uploading file {source_path} as {dest_name}")
                 try:
                     self._upload_file_with_retry(
-                        step.name, file_path=source_path, upload_id=step.upload_id
+                        step.name,
+                        file_path=source_path,
+                        upload_id=step.runtime_config.upload_id,
                     )
                 except Exception as e:
-                    print(
+                    logger.error(
                         f"Failed to upload file {source_path} after multiple retries: {e}"
                     )
                     raise

-            if step.config:
-                runtime_config = RuntimeConfig(
-                    max_steps=step.config.max_steps,
-                    upload_id=step.upload_id,
-                    environment_config={
-                        "eval": step.config.eval,
-                        "language": step.config.language,
-                    },
-                )
+            if step.llm_call:
+                task_responses = await self.call_llm(step)
+                task_ids = [f"llm_{str(uuid.uuid4())[:8]}"]
+                success_rate = 1
             else:
-                runtime_config = None
-
-            task_requests = self._create_task_requests(step, runtime_config)
-
-            print(
-                f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}"
-            )
-            try:
-                task_responses = await self._run_tasks_with_retry(
-                    task_requests,
-                    progress_bar=True,
-                    verbose=False,
-                    timeout=step.config.timeout,
-                )
-            except Exception as e:
-                print(
-                    f"Failed to run tasks for step {step.step_id} after multiple retries: {e}"
-                )
-                # Create an error result entry and continue to the next step
-                self.results[step.step_id] = {
-                    "task_ids": [],
-                    "task_responses": [],
-                    "success_rate": 0,
-                    "error": str(e),
-                }
-                continue
+                task_requests = self._create_task_requests(step, step.runtime_config)

-            task_ids = [str(task.task_id) for task in task_responses]
-            success_rate = sum(
-                [task.status == "success" for task in task_responses]
-            ) / len(task_responses)
-            print(f"Task success rate: {success_rate * 100}%")
+                logger.info(
+                    f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}"
+                )
+                try:
+                    task_responses = await self._run_tasks_with_retry(
+                        task_requests,
+                        progress_bar=True,
+                        verbose=False,
+                        timeout=step.timeout,
+                    )
+                except Exception as e:
+                    logger.error(
+                        f"Failed to run tasks for step {step.step_id} after multiple retries: {e}"
+                    )
+                    logger.error(f"Full traceback:\n{traceback.format_exc()}")
+                    # Create an error result entry and continue to the next step
+                    self.results[step.step_id] = {
+                        "task_ids": [],
+                        "task_responses": [],
+                        "success_rate": 0,
+                        "error": str(e),
+                    }
+                    continue
+
+                task_ids = [str(task.task_id) for task in task_responses]
+                success_rate = sum(
+                    [task.status == "success" for task in task_responses]
+                ) / len(task_responses)
+                logger.info(f"Task success rate: {success_rate * 100}%")

             self.results[step.step_id] = {
                 "task_ids": task_ids,
@@ -307,7 +324,7 @@ async def run_pipeline(
                         os.makedirs(
                             os.path.dirname(os.path.abspath(path)), exist_ok=True
                         )
-                        print(f"Downloading file {source_name} to {path}")
+                        logger.info(f"Downloading file {source_name} to {path}")
                         try:
                             self._download_file_with_retry(
                                 step.name,
@@ -316,21 +333,21 @@
                                 destination_path=path,
                             )
                         except Exception as e:
-                            print(
+                            logger.error(
                                 f"Failed to download {source_name} from task {task_id} after multiple retries: {e}"
                             )
                 except Exception as e:
-                    print(
+                    logger.error(
                         f"Error downloading {source_name} from task {task_id}: {e}"
                     )

             if step.post_process:
-                print(f"Running post-processing for step {step.step_id}")
+                logger.info(f"Running post-processing for step {step.step_id}")
                 step.post_process(
                     self.results[step.step_id], f"{output_dir}/{step.step_id}"
                 )

-            print(f"Completed step {i + 1}/{len(self.steps)}")
+            logger.info(f"Completed step {i + 1}/{len(self.steps)}")

         self.save_results(output_dir)
         return self.results
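
Note: taken together, these changes fold StepConfig into Step itself (timeout as a direct field, `parallel` renamed to n_replicate_tasks, a per-step RuntimeConfig) and add an LLM-only path via call_llm. A minimal usage sketch, assuming only the definitions in this diff — the job names, prompts, output directory, and FH_API_KEY environment variable are illustrative placeholders, and run_pipeline's full signature is not shown in this commit, so the output_dir keyword is an assumption:

import asyncio
import os

from fhda.tortoise import Step, Tortoise

# StepConfig is gone: timeout now lives on Step, and replication is
# controlled by n_replicate_tasks (formerly `parallel`).
analysis_step = Step(
    name="job-futurehouse-data-analysis-crow-high",  # placeholder job name
    prompt_template="Summarize the key trends in {dataset}.",
    prompt_args={"dataset": "experiment results"},
    n_replicate_tasks=2,
    timeout=15 * 60,
)

# With llm_call=True, run_pipeline skips task creation and calls
# Tortoise.call_llm instead (Anthropic, optionally with web search).
search_step = Step(
    name="literature-check",  # placeholder
    llm_call=True,
    include_search_tool=True,
    prompt_template="Find recent work on tortoise population genetics.",
)


async def main() -> None:
    # `environment` maps onto futurehouse_client's Stage enum via getattr,
    # falling back to Stage.PROD for unrecognized values.
    runner = Tortoise(api_key=os.environ["FH_API_KEY"], environment="PROD")
    runner.add_step(analysis_step)
    runner.add_step(search_step)
    results = await runner.run_pipeline(output_dir="output")
    print(results.keys())  # one result entry per step_id


asyncio.run(main())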

tutorial/consensus.ipynb

Lines changed: 3 additions & 3 deletions
@@ -141,7 +141,7 @@
     "    max_steps=30,\n",
     "    upload_id=DEA_UPLOAD_ID,\n",
     "    environment_config={\n",
-    "        \"eval\": True, # DO NOT CHANGE THIS\n",
+    "        \"default_cot_prompt\": False,\n",
     "        \"language\": \"R\",\n",
     "    },\n",
     ")\n",
@@ -200,7 +200,7 @@
     "    max_steps=30,\n",
     "    upload_id=CONSENSUS_UPLOAD_ID,\n",
     "    environment_config={\n",
-    "        \"eval\": True, # DO NOT CHANGE THIS\n",
+    "        \"default_cot_prompt\": False,\n",
     "        \"language\": \"R\",\n",
     "    },\n",
     ")\n",
@@ -304,7 +304,7 @@
     "    max_steps=30,\n",
     "    upload_id=PQA_UPLOAD_ID,\n",
     "    environment_config={\n",
-    "        \"eval\": True, # DO NOT CHANGE THIS\n",
+    "        \"default_cot_prompt\": False,\n",
     "        \"language\": \"PYTHON\",\n",
     "    },\n",
     ")\n",
