6 changes: 4 additions & 2 deletions pyproject.toml
@@ -8,6 +8,7 @@ authors = [
]
dependencies = [
"aiodocker==0.24.0",
"anthropic==0.52.2", # this is necessary for tortoise, remove in favor of LMI when it works with search
"fhaviary[server]==0.19.0",
"ldp==0.26.0",
"pandas==2.2.3",
@@ -17,11 +18,12 @@ dependencies = [
"google-auth==2.38.0",
"google-cloud-storage==3.0.0",
"google-cloud-secret-manager==2.23.0",
"futurehouse-client==0.3.18",
"futurehouse-client==0.3.19",
"jupyter==1.1.1",
"nbconvert==7.16.6",
"notebook==7.3.2",
"nbformat==5.10.4"
"nbformat==5.10.4",
"seaborn==0.13.2"
]
description = "Data analysis crow"
name = "fhda"
227 changes: 122 additions & 105 deletions src/fhda/tortoise.py
@@ -13,28 +13,17 @@
wait_exponential,
retry_if_exception_type,
)
from . import prompts
from . import config as cfg

from futurehouse_client import FutureHouseClient
from futurehouse_client.models import TaskRequest, RuntimeConfig
from futurehouse_client.models.app import AuthType
from futurehouse_client.clients.rest_client import TaskFetchError
from futurehouse_client.models.app import AuthType, Stage
import anthropic
import logging
import traceback


class StepConfig(BaseModel):
"""Agent runtime configuration."""

language: str = Field(
default="PYTHON", description="Language for execution environment"
)
max_steps: int = Field(
default=30, description="Maximum number of steps for the agent"
)
timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds")
eval: bool = Field(
default=True,
description="For Finch, this indicates whether this is an API call or UI call. Setting it to True removes the automatic CoT additions.",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Step(BaseModel):
@@ -43,10 +32,16 @@ class Step(BaseModel):
name: str = Field(
description="Name of the job to run (e.g. 'job-futurehouse-data-analysis-crow-high')"
)
prompt_template: str = Field(description="Prompt template to use for the step")
cot_prompt: bool = Field(
default=False, description="Whether to augment the query with COT prompting"
llm_call: bool = Field(
default=False, description="Whether to call the LLM for the step"
)
include_search_tool: bool = Field(
default=False, description="Whether to include the search tool in the LLM call"
)
model_name: str = Field(
default=cfg.DEFAULT_MODEL, description="Name of the model to use for the step"
)
prompt_template: str = Field(description="Prompt template to use for the step")
prompt_args: dict[str, Any] = Field(
default_factory=dict,
description="Keyword arguments to format the prompt template.",
@@ -59,13 +54,14 @@ class Step(BaseModel):
description="Files to download {'source_name': 'dest_path'}",
)
step_id: str = Field(
default_factory=lambda: str(uuid.uuid4())[:8],
default_factory=lambda: str(uuid.uuid4()),
description="Small UID for the step",
)
upload_id: Optional[str] = Field(default=None, description="Upload ID for GCS")
parallel: int = Field(default=1, description="Number of parallel tasks to run")
config: StepConfig = Field(
default_factory=StepConfig, description="Configuration for the step"
n_replicate_tasks: int = Field(
default=1, description="Number of parallel tasks to run"
)
runtime_config: RuntimeConfig = Field(
default_factory=RuntimeConfig, description="Configuration for the step"
)
post_process: Optional[Callable[[dict[str, Any], str], None]] = Field(
default=None, description="Function to run after step completion"
Expand All @@ -74,36 +70,24 @@ class Step(BaseModel):
default=None,
description="Function to generate prompts and args for parallel tasks based on previous results",
)

def cot_prompting(self, query: str, language: str) -> str:
"""Apply chain-of-thought prompting to the query."""
guidelines = prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language)
if language == "R":
guidelines = prompts.R_SPECIFIC_GUIDELINES.format(language=language)
return (
f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language)}\n"
f"{guidelines}"
f"Here is the research question to address:\n"
f"<query>\n"
f"{query}\n"
f"</query>\n"
)
timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds")

def format_prompt(self) -> str:
"""Format the prompt template with the provided arguments."""
final_prompt = self.prompt_template.format(**self.prompt_args)
if self.cot_prompt:
final_prompt = self.cot_prompting(final_prompt, self.config.language)
return final_prompt
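As a usage sketch (the prompt, dataset, and file paths below are illustrative assumptions, not values from this PR), a `Step` combining the renamed and added fields might be constructed like this:

```python
from futurehouse_client.models import RuntimeConfig
from fhda.tortoise import Step

# Illustrative sketch only: the prompt, dataset, and paths are assumptions.
step = Step(
    name="job-futurehouse-data-analysis-crow-high",
    llm_call=False,  # run as a platform task rather than a direct LLM call
    prompt_template="Analyze {dataset} and summarize the key trends.",
    prompt_args={"dataset": "expression_matrix.csv"},
    input_files={"local/expression_matrix.csv": "expression_matrix.csv"},
    n_replicate_tasks=3,  # formerly `parallel`
    runtime_config=RuntimeConfig(max_steps=30),
    timeout=15 * 60,
)
print(step.format_prompt())  # prompt_args are interpolated into the template
```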


class Tortoise:
"""Runner for multi-step agent pipelines."""

def __init__(self, api_key: str):
def __init__(self, api_key: str, environment: str = "PROD"):
"""Initialize the tortoise framework with FutureHouse API key."""
self.client = FutureHouseClient(
auth_type=AuthType.API_KEY, api_key=api_key, verbose_logging=True
auth_type=AuthType.API_KEY,
api_key=api_key,
verbose_logging=True,
stage=getattr(Stage, environment.upper(), Stage.PROD),
)
self.steps: list[Step] = []
self.results: dict[str, Any] = {}
@@ -115,18 +99,18 @@ def add_step(self, step: Step) -> None:
def save_results(self, output_dir: str | PathLike = "output") -> None:
"""Save the results to a JSON file."""
results_path = f"{output_dir}/results_{time.strftime('%Y%m%d_%H%M%S')}.json"
print(f"Saving all results to {results_path}")
logger.info(f"Saving all results to {results_path}")
try:
os.makedirs(output_dir, exist_ok=True)
serializable_results = {}
for step_id, step_result in self.results.items():
serializable_results[step_id] = dict(step_result)

with open(results_path, "w") as f:
json.dump(serializable_results, f, indent=2)
print(f"Results successfully saved to {results_path}")
json.dump(serializable_results, f, indent=2, default=str)
logger.info(f"Results successfully saved to {results_path}")
except Exception as e:
print(f"Error saving results to {results_path}: {e}")
logger.error(f"Error saving results to {results_path}: {e}")

@retry(
stop=stop_after_attempt(3),
@@ -168,7 +152,21 @@ def _create_task_requests(
List of task requests to be executed
"""
task_requests = []
task_count = max(step.parallel, 1)
task_count = max(step.n_replicate_tasks, 1)

if step.model_name:
agent_config = cfg.get_custom_agent_config(step.model_name)
runtime_config.agent = agent_config

if step.runtime_config.continued_job_id:
task_ids = self.results[str(step.runtime_config.continued_job_id)][
"task_ids"
]
if len(task_ids) > 1:
logger.warning(
f"Continued job {step.runtime_config.continued_job_id} has multiple task ids, using the first one"
)
runtime_config.continued_job_id = str(task_ids[0])

if step.prompt_generator and task_count > 1:
# Generate dynamic prompts based on previous results
Expand Down Expand Up @@ -201,11 +199,34 @@ def _create_task_requests(

return task_requests
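The `continued_job_id` handling above resolves a previous step's `step_id` to the concrete task id that step produced. A sketch of chaining a follow-up step this way (the job name and prompt are illustrative):

```python
from futurehouse_client.models import RuntimeConfig
from fhda.tortoise import Step

# Sketch: chain a follow-up step onto a completed step's task. Assumes
# `first_step` has already run, so its step_id keys into Tortoise.results.
followup = Step(
    name="job-futurehouse-data-analysis-crow-high",
    prompt_template="Continue the analysis and produce final figures.",
    runtime_config=RuntimeConfig(continued_job_id=first_step.step_id),
)
```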

@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=2, max=30),
retry=retry_if_exception_type((Exception, TaskFetchError)),
)
async def call_llm(self, step: Step) -> list:
"""Call the LLM for the step."""
anthropic_client = anthropic.Anthropic()
# TODO: This is a hack to get the model name without the provider prefix
model_name = step.model_name.replace("anthropic/", "")
if step.include_search_tool:
tools = [
{
"type": "web_search_20250305",
"name": "web_search",
}
]
else:
tools = []
response = anthropic_client.messages.create(
model=model_name,
messages=[
{
"role": "user",
"content": step.prompt_template,
}
],
tools=tools,
max_tokens=8192,
)
result = "\n".join([r.text for r in response.content if hasattr(r, "text")])
return [result]
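A sketch of exercising this direct-LLM path (assumes `ANTHROPIC_API_KEY` is set in the environment, which `anthropic.Anthropic()` reads by default, and that the model alias below is available on the account):

```python
import asyncio

from fhda.tortoise import Step, Tortoise

# Sketch only: the model alias and prompt are assumptions; "fh-..." is a
# placeholder FutureHouse API key.
step = Step(
    name="llm-only-step",
    llm_call=True,
    include_search_tool=True,  # attaches Anthropic's web_search tool
    model_name="anthropic/claude-3-5-sonnet-latest",  # provider prefix is stripped
    prompt_template="List notable changes in recent pandas releases.",
)
tortoise = Tortoise(api_key="fh-...")
print(asyncio.run(tortoise.call_llm(step))[0])
```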

async def _run_tasks_with_retry(
self, task_requests, progress_bar, verbose, timeout
):
@@ -225,64 +246,60 @@ async def run_pipeline(
os.makedirs(output_dir, exist_ok=True)

for i, step in enumerate(self.steps):
print(f"Running step {i + 1}/{len(self.steps)}: {step.name}")
if not step.upload_id:
step.upload_id = f"{step.name}_{step.step_id}"
logger.info(f"Running step {i + 1}/{len(self.steps)}: {step.name}")
if not step.runtime_config.upload_id:
step.runtime_config.upload_id = step.step_id

for source_path, dest_name in step.input_files.items():
print(f"Uploading file {source_path} as {dest_name}")
logger.info(f"Uploading file {source_path} as {dest_name}")
try:
self._upload_file_with_retry(
step.name, file_path=source_path, upload_id=step.upload_id
step.name,
file_path=source_path,
upload_id=step.runtime_config.upload_id,
)
except Exception as e:
print(
logger.error(
f"Failed to upload file {source_path} after multiple retries: {e}"
)
raise

if step.config:
runtime_config = RuntimeConfig(
max_steps=step.config.max_steps,
upload_id=step.upload_id,
environment_config={
"eval": step.config.eval,
"language": step.config.language,
},
)
if step.llm_call:
task_responses = await self.call_llm(step)
task_ids = [f"llm_{str(uuid.uuid4())[:8]}"]
success_rate = 1
else:
runtime_config = None

task_requests = self._create_task_requests(step, runtime_config)

print(
f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}"
)
try:
task_responses = await self._run_tasks_with_retry(
task_requests,
progress_bar=True,
verbose=False,
timeout=step.config.timeout,
)
except Exception as e:
print(
f"Failed to run tasks for step {step.step_id} after multiple retries: {e}"
)
# Create an error result entry and continue to the next step
self.results[step.step_id] = {
"task_ids": [],
"task_responses": [],
"success_rate": 0,
"error": str(e),
}
continue
task_requests = self._create_task_requests(step, step.runtime_config)

task_ids = [str(task.task_id) for task in task_responses]
success_rate = sum(
[task.status == "success" for task in task_responses]
) / len(task_responses)
print(f"Task success rate: {success_rate * 100}%")
logger.info(
f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}"
)
try:
task_responses = await self._run_tasks_with_retry(
task_requests,
progress_bar=True,
verbose=False,
timeout=step.timeout,
)
except Exception as e:
logger.error(
f"Failed to run tasks for step {step.step_id} after multiple retries: {e}"
)
logger.error(f"Full traceback:\n{traceback.format_exc()}")
# Create an error result entry and continue to the next step
self.results[step.step_id] = {
"task_ids": [],
"task_responses": [],
"success_rate": 0,
"error": str(e),
}
continue

task_ids = [str(task.task_id) for task in task_responses]
success_rate = sum(
[task.status == "success" for task in task_responses]
) / len(task_responses)
logger.info(f"Task success rate: {success_rate * 100}%")

self.results[step.step_id] = {
"task_ids": task_ids,
@@ -307,7 +324,7 @@ async def run_pipeline(
os.makedirs(
os.path.dirname(os.path.abspath(path)), exist_ok=True
)
print(f"Downloading file {source_name} to {path}")
logger.info(f"Downloading file {source_name} to {path}")
try:
self._download_file_with_retry(
step.name,
Expand All @@ -316,21 +333,21 @@ async def run_pipeline(
destination_path=path,
)
except Exception as e:
print(
logger.error(
f"Failed to download {source_name} from task {task_id} after multiple retries: {e}"
)
except Exception as e:
print(
logger.error(
f"Error downloading {source_name} from task {task_id}: {e}"
)

if step.post_process:
print(f"Running post-processing for step {step.step_id}")
logger.info(f"Running post-processing for step {step.step_id}")
step.post_process(
self.results[step.step_id], f"{output_dir}/{step.step_id}"
)

print(f"Completed step {i + 1}/{len(self.steps)}")
logger.info(f"Completed step {i + 1}/{len(self.steps)}")

self.save_results(output_dir)
return self.results
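End to end, the new constructor argument and the pipeline loop might be driven like this (a sketch: the key is a placeholder, and `run_pipeline` is assumed to take an `output_dir` the way `save_results` does):

```python
import asyncio

from fhda.tortoise import Tortoise

# Sketch only. An unknown environment name falls back to Stage.PROD via the
# getattr(Stage, ..., Stage.PROD) default in __init__.
tortoise = Tortoise(api_key="fh-...", environment="DEV")
tortoise.add_step(step)  # e.g. the Step sketched earlier
results = asyncio.run(tortoise.run_pipeline(output_dir="output"))
# Per-step task ids, responses, and success rates are also written to
# output/results_<timestamp>.json by save_results().
```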
6 changes: 3 additions & 3 deletions tutorial/consensus.ipynb
@@ -141,7 +141,7 @@
" max_steps=30,\n",
" upload_id=DEA_UPLOAD_ID,\n",
" environment_config={\n",
" \"eval\": True, # DO NOT CHANGE THIS\n",
" \"default_cot_prompt\": False,\n",
" \"language\": \"R\",\n",
" },\n",
")\n",
@@ -200,7 +200,7 @@
" max_steps=30,\n",
" upload_id=CONSENSUS_UPLOAD_ID,\n",
" environment_config={\n",
" \"eval\": True, # DO NOT CHANGE THIS\n",
" \"default_cot_prompt\": False,\n",
" \"language\": \"R\",\n",
" },\n",
")\n",
@@ -304,7 +304,7 @@
" max_steps=30,\n",
" upload_id=PQA_UPLOAD_ID,\n",
" environment_config={\n",
" \"eval\": True, # DO NOT CHANGE THIS\n",
" \"default_cot_prompt\": False,\n",
" \"language\": \"PYTHON\",\n",
" },\n",
")\n",
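For reference, the notebook cells above now pass `default_cot_prompt` instead of the old `eval` flag; the resulting config shape (the upload ID is a placeholder):

```python
from futurehouse_client.models import RuntimeConfig

# Mirrors the updated notebook cells; the upload_id value is a placeholder.
runtime_config = RuntimeConfig(
    max_steps=30,
    upload_id="DEA_UPLOAD_ID",
    environment_config={
        "default_cot_prompt": False,  # replaces the old `"eval": True` flag
        "language": "R",
    },
)
```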