1313 wait_exponential ,
1414 retry_if_exception_type ,
1515)
16- from . import prompts
16+ from . import config as cfg
1717
1818from futurehouse_client import FutureHouseClient
1919from futurehouse_client .models import TaskRequest , RuntimeConfig
20- from futurehouse_client .models .app import AuthType
21- from futurehouse_client .clients .rest_client import TaskFetchError
20+ from futurehouse_client .models .app import AuthType , Stage
21+ import anthropic
22+ import logging
23+ import traceback
2224
23-
24- class StepConfig (BaseModel ):
25- """Agent runtime configuration."""
26-
27- language : str = Field (
28- default = "PYTHON" , description = "Language for execution environment"
29- )
30- max_steps : int = Field (
31- default = 30 , description = "Maximum number of steps for the agent"
32- )
33- timeout : int = Field (default = 15 * 60 , description = "Timeout for the step in seconds" )
34- eval : bool = Field (
35- default = True ,
36- description = "For Finch, this indicates whether this is an API call or UI call. Setting it to True removes the automatic CoT additions." ,
37- )
# Module-level logger named after this module (standard logging pattern).
logger = logging.getLogger(__name__)
# NOTE(review): setting a level at import time overrides any application-side
# logging configuration — confirm this is intended for a library module.
logger.setLevel(logging.INFO)
3827
3928
4029class Step (BaseModel ):
@@ -43,10 +32,16 @@ class Step(BaseModel):
4332 name : str = Field (
4433 description = "Name of the job to run (e.g. 'job-futurehouse-data-analysis-crow-high')"
4534 )
46- prompt_template : str = Field (description = "Prompt template to use for the step" )
47- cot_prompt : bool = Field (
48- default = False , description = "Whether to augment the query with COT prompting"
35+ llm_call : bool = Field (
36+ default = False , description = "Whether to call the LLM for the step"
37+ )
38+ include_search_tool : bool = Field (
39+ default = False , description = "Whether to include the search tool in the LLM call"
4940 )
41+ model_name : str = Field (
42+ default = cfg .DEFAULT_MODEL , description = "Name of the model to use for the step"
43+ )
44+ prompt_template : str = Field (description = "Prompt template to use for the step" )
5045 prompt_args : dict [str , Any ] = Field (
5146 default_factory = dict ,
5247 description = "Keyword arguments to format the prompt template." ,
@@ -59,13 +54,14 @@ class Step(BaseModel):
5954 description = "Files to download {'source_name': 'dest_path'}" ,
6055 )
6156 step_id : str = Field (
62- default_factory = lambda : str (uuid .uuid4 ())[: 8 ] ,
57+ default_factory = lambda : str (uuid .uuid4 ()),
6358 description = "Small UID for the step" ,
6459 )
65- upload_id : Optional [str ] = Field (default = None , description = "Upload ID for GCS" )
66- parallel : int = Field (default = 1 , description = "Number of parallel tasks to run" )
67- config : StepConfig = Field (
68- default_factory = StepConfig , description = "Configuration for the step"
60+ n_replicate_tasks : int = Field (
61+ default = 1 , description = "Number of parallel tasks to run"
62+ )
63+ runtime_config : RuntimeConfig = Field (
64+ default_factory = RuntimeConfig , description = "Configuration for the step"
6965 )
7066 post_process : Optional [Callable [[dict [str , Any ], str ], None ]] = Field (
7167 default = None , description = "Function to run after step completion"
@@ -74,36 +70,24 @@ class Step(BaseModel):
7470 default = None ,
7571 description = "Function to generate prompts and args for parallel tasks based on previous results" ,
7672 )
77-
78- def cot_prompting (self , query : str , language : str ) -> str :
79- """Apply chain-of-thought prompting to the query."""
80- guidelines = prompts .GENERAL_NOTEBOOK_GUIDELINES .format (language = language )
81- if language == "R" :
82- guidelines = prompts .R_SPECIFIC_GUIDELINES .format (language = language )
83- return (
84- f"{ prompts .CHAIN_OF_THOUGHT_AGNOSTIC .format (language = language )} \n "
85- f"{ guidelines } "
86- f"Here is the research question to address:\n "
87- f"<query>\n "
88- f"{ query } \n "
89- f"</query>\n "
90- )
73+ timeout : int = Field (default = 15 * 60 , description = "Timeout for the step in seconds" )
9174
def format_prompt(self) -> str:
    """Render the step's prompt template using the stored keyword arguments."""
    return self.prompt_template.format(**self.prompt_args)
9879
9980
10081class Tortoise :
10182 """Runner for multi-step agent pipelines."""
10283
def __init__(self, api_key: str, environment: str = "PROD"):
    """Initialize the tortoise framework with a FutureHouse API key.

    Args:
        api_key: FutureHouse API key used for ``AuthType.API_KEY`` auth.
        environment: Stage name (e.g. "PROD"); matched case-insensitively
            against ``Stage`` members, falling back to ``Stage.PROD`` for
            unknown names.
    """
    self.client = FutureHouseClient(
        auth_type=AuthType.API_KEY,
        api_key=api_key,
        verbose_logging=True,
        # Unknown environment names silently fall back to PROD rather than raising.
        stage=getattr(Stage, environment.upper(), Stage.PROD),
    )
    # Pipeline steps, executed in insertion order by run_pipeline.
    self.steps: list[Step] = []
    # Per-step results keyed by step_id.
    self.results: dict[str, Any] = {}
@@ -115,18 +99,18 @@ def add_step(self, step: Step) -> None:
def save_results(self, output_dir: str | PathLike = "output") -> None:
    """Write all accumulated step results to a timestamped JSON file.

    Best-effort: failures are logged (with traceback) but never raised, so a
    save problem cannot abort a completed pipeline run.

    Args:
        output_dir: Directory for the output file; created if missing.
    """
    results_path = f"{output_dir}/results_{time.strftime('%Y%m%d_%H%M%S')}.json"
    logger.info("Saving all results to %s", results_path)
    try:
        os.makedirs(output_dir, exist_ok=True)
        # Shallow-copy each step result so every value serializes as a plain dict.
        serializable_results = {
            step_id: dict(step_result)
            for step_id, step_result in self.results.items()
        }
        with open(results_path, "w", encoding="utf-8") as f:
            # default=str stringifies non-JSON types (UUIDs, datetimes, ...).
            json.dump(serializable_results, f, indent=2, default=str)
        logger.info("Results successfully saved to %s", results_path)
    except Exception:
        # logger.exception records the full traceback; the previous
        # logger.error(f"... {e}") dropped it. Deliberately swallowed after
        # logging to keep saving best-effort.
        logger.exception("Error saving results to %s", results_path)
130114
131115 @retry (
132116 stop = stop_after_attempt (3 ),
@@ -168,7 +152,21 @@ def _create_task_requests(
168152 List of task requests to be executed
169153 """
170154 task_requests = []
171- task_count = max (step .parallel , 1 )
155+ task_count = max (step .n_replicate_tasks , 1 )
156+
157+ if step .model_name :
158+ agent_config = cfg .get_custom_agent_config (step .model_name )
159+ runtime_config .agent = agent_config
160+
161+ if step .runtime_config .continued_job_id :
162+ task_ids = self .results [str (step .runtime_config .continued_job_id )][
163+ "task_ids"
164+ ]
165+ if len (task_ids ) > 1 :
166+ logger .warning (
167+ f"Continued job { step .runtime_config .continued_job_id } has multiple task ids, using the first one"
168+ )
169+ runtime_config .continued_job_id = str (task_ids [0 ])
172170
173171 if step .prompt_generator and task_count > 1 :
174172 # Generate dynamic prompts based on previous results
@@ -201,11 +199,34 @@ def _create_task_requests(
201199
202200 return task_requests
203201
204- @retry (
205- stop = stop_after_attempt (5 ),
206- wait = wait_exponential (multiplier = 1 , min = 2 , max = 30 ),
207- retry = retry_if_exception_type ((Exception , TaskFetchError )),
208- )
async def call_llm(self, step: Step) -> list:
    """Call the Anthropic API directly for an LLM-only step.

    Args:
        step: Step whose ``prompt_template`` is sent verbatim as the user
            message; ``include_search_tool`` toggles the web-search tool.

    Returns:
        Single-element list containing the concatenated text content of the
        model response.
    """
    # Use the async client: the original synchronous anthropic.Anthropic()
    # call blocked the event loop inside this coroutine.
    anthropic_client = anthropic.AsyncAnthropic()
    # TODO: This is a hack to get the model name without the provider prefix.
    # removeprefix strips only a leading "anthropic/", unlike replace(),
    # which would remove the substring anywhere in the name.
    model_name = step.model_name.removeprefix("anthropic/")
    if step.include_search_tool:
        tools = [
            {
                "type": "web_search_20250305",
                "name": "web_search",
            }
        ]
    else:
        tools = []
    response = await anthropic_client.messages.create(
        model=model_name,
        messages=[
            {
                "role": "user",
                "content": step.prompt_template,
            }
        ],
        tools=tools,
        max_tokens=8192,
    )
    # Keep only text blocks; tool-use blocks carry no .text attribute.
    result = "\n".join(r.text for r in response.content if hasattr(r, "text"))
    return [result]
229+
209230 async def _run_tasks_with_retry (
210231 self , task_requests , progress_bar , verbose , timeout
211232 ):
@@ -225,64 +246,60 @@ async def run_pipeline(
225246 os .makedirs (output_dir , exist_ok = True )
226247
227248 for i , step in enumerate (self .steps ):
228- print (f"Running step { i + 1 } /{ len (self .steps )} : { step .name } " )
229- if not step .upload_id :
230- step .upload_id = f" { step .name } _ { step . step_id } "
249+ logger . info (f"Running step { i + 1 } /{ len (self .steps )} : { step .name } " )
250+ if not step .runtime_config . upload_id :
251+ step .runtime_config . upload_id = step .step_id
231252
232253 for source_path , dest_name in step .input_files .items ():
233- print (f"Uploading file { source_path } as { dest_name } " )
254+ logger . info (f"Uploading file { source_path } as { dest_name } " )
234255 try :
235256 self ._upload_file_with_retry (
236- step .name , file_path = source_path , upload_id = step .upload_id
257+ step .name ,
258+ file_path = source_path ,
259+ upload_id = step .runtime_config .upload_id ,
237260 )
238261 except Exception as e :
239- print (
262+ logger . error (
240263 f"Failed to upload file { source_path } after multiple retries: { e } "
241264 )
242265 raise
243266
244- if step .config :
245- runtime_config = RuntimeConfig (
246- max_steps = step .config .max_steps ,
247- upload_id = step .upload_id ,
248- environment_config = {
249- "eval" : step .config .eval ,
250- "language" : step .config .language ,
251- },
252- )
267+ if step .llm_call :
268+ task_responses = await self .call_llm (step )
269+ task_ids = [f"llm_{ str (uuid .uuid4 ())[:8 ]} " ]
270+ success_rate = 1
253271 else :
254- runtime_config = None
255-
256- task_requests = self ._create_task_requests (step , runtime_config )
257-
258- print (
259- f"Running { len (task_requests )} task{ 's' if len (task_requests ) > 1 else '' } "
260- )
261- try :
262- task_responses = await self ._run_tasks_with_retry (
263- task_requests ,
264- progress_bar = True ,
265- verbose = False ,
266- timeout = step .config .timeout ,
267- )
268- except Exception as e :
269- print (
270- f"Failed to run tasks for step { step .step_id } after multiple retries: { e } "
271- )
272- # Create an error result entry and continue to the next step
273- self .results [step .step_id ] = {
274- "task_ids" : [],
275- "task_responses" : [],
276- "success_rate" : 0 ,
277- "error" : str (e ),
278- }
279- continue
272+ task_requests = self ._create_task_requests (step , step .runtime_config )
280273
281- task_ids = [str (task .task_id ) for task in task_responses ]
282- success_rate = sum (
283- [task .status == "success" for task in task_responses ]
284- ) / len (task_responses )
285- print (f"Task success rate: { success_rate * 100 } %" )
274+ logger .info (
275+ f"Running { len (task_requests )} task{ 's' if len (task_requests ) > 1 else '' } "
276+ )
277+ try :
278+ task_responses = await self ._run_tasks_with_retry (
279+ task_requests ,
280+ progress_bar = True ,
281+ verbose = False ,
282+ timeout = step .timeout ,
283+ )
284+ except Exception as e :
285+ logger .error (
286+ f"Failed to run tasks for step { step .step_id } after multiple retries: { e } "
287+ )
288+ logger .error (f"Full traceback:\n { traceback .format_exc ()} " )
289+ # Create an error result entry and continue to the next step
290+ self .results [step .step_id ] = {
291+ "task_ids" : [],
292+ "task_responses" : [],
293+ "success_rate" : 0 ,
294+ "error" : str (e ),
295+ }
296+ continue
297+
298+ task_ids = [str (task .task_id ) for task in task_responses ]
299+ success_rate = sum (
300+ [task .status == "success" for task in task_responses ]
301+ ) / len (task_responses )
302+ logger .info (f"Task success rate: { success_rate * 100 } %" )
286303
287304 self .results [step .step_id ] = {
288305 "task_ids" : task_ids ,
@@ -307,7 +324,7 @@ async def run_pipeline(
307324 os .makedirs (
308325 os .path .dirname (os .path .abspath (path )), exist_ok = True
309326 )
310- print (f"Downloading file { source_name } to { path } " )
327+ logger . info (f"Downloading file { source_name } to { path } " )
311328 try :
312329 self ._download_file_with_retry (
313330 step .name ,
@@ -316,21 +333,21 @@ async def run_pipeline(
316333 destination_path = path ,
317334 )
318335 except Exception as e :
319- print (
336+ logger . error (
320337 f"Failed to download { source_name } from task { task_id } after multiple retries: { e } "
321338 )
322339 except Exception as e :
323- print (
340+ logger . error (
324341 f"Error downloading { source_name } from task { task_id } : { e } "
325342 )
326343
327344 if step .post_process :
328- print (f"Running post-processing for step { step .step_id } " )
345+ logger . info (f"Running post-processing for step { step .step_id } " )
329346 step .post_process (
330347 self .results [step .step_id ], f"{ output_dir } /{ step .step_id } "
331348 )
332349
333- print (f"Completed step { i + 1 } /{ len (self .steps )} " )
350+ logger . info (f"Completed step { i + 1 } /{ len (self .steps )} " )
334351
335352 self .save_results (output_dir )
336353 return self .results
0 commit comments