Commit 1942958

Continuations (#20)
1 parent 9d4a075 commit 1942958

9 files changed (+233, -51)

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ repos:
     rev: 25.1.0
     hooks:
       - id: black-jupyter
+        types: [jupyter]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
     hooks:

src/fhda/Dockerfile.custom_deployment

Lines changed: 14 additions & 21 deletions
@@ -5,8 +5,7 @@ WORKDIR /app
 ENV PYTHONUNBUFFERED=1
 ENV DEBIAN_FRONTEND=noninteractive
 
-RUN --mount=type=cache,target=/var/cache/apt \
-    apt-get update -qq && \
+RUN apt-get update -qq && \
     apt-get install -yq --no-install-recommends \
     git \
     openssh-client \
@@ -28,15 +27,13 @@ ENV PATH="/app/miniconda/bin:$PATH"
 ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}"
 
 # Install uv & mamba
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip3 install --no-cache-dir uv==0.5.21
-RUN --mount=type=cache,target=/app/miniconda/pkgs \
-    conda install -c conda-forge mamba -y
+RUN pip3 install --no-cache-dir uv==0.5.21
+RUN conda install -c conda-forge mamba -y
 
 # Install R and kernels in the crow_env environment
-RUN --mount=type=cache,target=/app/miniconda/pkgs \
-    mamba install -c conda-forge -y \
+RUN mamba install -c conda-forge -y \
     r-base=4.3.3 \
+    r-r.utils=2.13.0 \
     r-recommended=4.3 \
     r-irkernel=1.3.2 \
     r-factominer=2.11 \
@@ -86,13 +83,10 @@ RUN --mount=type=cache,target=/app/miniconda/pkgs \
     statsmodels=0.14.4 \
     umap-learn=0.5.7
 
-RUN --mount=type=cache,target=/app/miniconda/pkgs \
-    python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)"
-RUN --mount=type=cache,target=/app/miniconda/pkgs \
-    R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")'
+RUN python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)"
+RUN R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")'
 
-RUN --mount=type=cache,target=/app/miniconda/pkgs \
-    mamba install -c conda-forge -c bioconda -y \
+RUN mamba install -c conda-forge -c bioconda -y \
     biokit=0.5.0 \
     gseapy=1.1.4 \
     blast=2.16.0 \
@@ -116,7 +110,9 @@ RUN --mount=type=cache,target=/app/miniconda/pkgs \
     bioconductor-summarizedexperiment=1.32.0 \
     bioconductor-apeglm=1.24.0 \
     bioconductor-flowcore=2.14.0 \
-    bioconductor-flowmeans=1.62.0
+    bioconductor-flowmeans=1.62.0 \
+    bioconductor-limma=3.58.1 \
+    bioconductor-geoquery=2.70.0
 
 ENV UV_COMPILE_BYTECODE=1
 ENV UV_LINK_MODE=copy
@@ -131,8 +127,7 @@ FROM base AS builder
 
 ARG MODULE_NAME
 
-RUN --mount=type=cache,target=/var/cache/apt \
-    apt-get update -qq && \
+RUN apt-get update -qq && \
     apt-get install -yq --no-install-recommends \
     build-essential && \
     apt-get clean && rm -rf /var/lib/apt/lists/*
@@ -147,9 +142,7 @@ COPY ./scripts/run_crow_job.py /app/scripts/
 
 # Install application dependencies (this will only rerun when code changes)
 WORKDIR /app/${MODULE_NAME}
-RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=cache,target=/app/miniconda/pkgs \
-    if [ -f "pyproject.toml" ]; then \
+RUN if [ -f "pyproject.toml" ]; then \
     uv pip install --system -e .; \
     elif [ -f "requirements.txt" ]; then \
     uv pip install --system -r requirements.txt; \
@@ -167,4 +160,4 @@ COPY --from=builder /app/ /app/
 ENV VIRTUAL_ENV="/app/miniconda/bin"
 ENV PATH="/app/miniconda/bin:$PATH"
 ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}"
-CMD ["python", "scripts/run_crow_job.py"]
+CMD ["python", "scripts/run_crow_job.py"]

src/fhda/config.py

Lines changed: 1 addition & 1 deletion
@@ -25,4 +25,4 @@
 # FutureHosue client config
 ENVIRONMENT = os.getenv("ENVIRONMENT", "prod")
 CROW_STAGE = getattr(Stage, ENVIRONMENT.upper(), Stage.PROD)
-PLATFORM_API_KEY = os.getenv("CROW_API_KEY", None)
+PLATFORM_API_KEY = os.getenv("FH_API_KEY", None)
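Note the renamed variable: the key is now read from FH_API_KEY, and CROW_API_KEY is no longer consulted, so deployments still exporting only the old name will fall into the no-key path with a None default. A minimal sketch of the resulting behavior (standard library only; the print is illustrative):

import os

# Only FH_API_KEY is consulted after this commit; CROW_API_KEY is ignored.
platform_api_key = os.getenv("FH_API_KEY", None)

if platform_api_key is None:
    # Mirrors the guard in data_analysis_env.py: without a key, the previous
    # trajectory cannot be fetched and a continuation runs without prior context.
    print("FH_API_KEY not set; previous trajectory will not be fetched")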

src/fhda/data_analysis_env.py

Lines changed: 71 additions & 11 deletions
@@ -15,7 +15,7 @@
 from futurehouse_client import FutureHouseClient
 
 from .notebook_env import NBEnvironment
-from .utils import NBLanguage, MultipleChoiceQuestion
+from .utils import NBLanguage, MultipleChoiceQuestion, extract_xml_content
 from . import prompts
 from . import config as cfg
 
@@ -174,6 +174,7 @@ def from_task(
         trajectory_id: str | None = None,
         user_id: str | None = None,
         environment_config: dict[str, Any] | None = None,
+        continued_trajectory_id: str | None = None,
     ) -> "DataAnalysisEnv":
         """
         Perform data analysis on a user query.
@@ -188,18 +189,21 @@ def from_task(
         logger.info("environment_config: %s", environment_config)
         logger.info("trajectory_id: %s", trajectory_id)
         logger.info("user_id: %s", user_id)
-        # Track cost of running the environment
+        logger.info("continued_trajectory_id: %s", continued_trajectory_id)
         enable_cost_tracking()
+
         if (
-            not gcs_artifact_path
+            (not gcs_artifact_path) and not continued_trajectory_id
         ):  # Platform jobs should always be associated with data from a GCS bucket
             raise NotImplementedError(
                 "Running crow jobs without gcs_artifact_path is not supported"
             )
 
         if user_id is None:
+            logger.warning("No user_id provided, using default_user")
             user_id = "default_user"
         if trajectory_id is None:
+            logger.warning("No trajectory_id provided, using time-based id")
             trajectory_id = f"{gcs_artifact_path}-{time.time()}"
         if environment_config:
             kwargs = {
@@ -214,11 +218,49 @@ def from_task(
             trajectory_path = (
                 cfg.DATA_STORAGE_PATH / "user_trajectories" / user_id / trajectory_id
             )
-            if environment_config.get("gcs_override", False):
-                data_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path
+
+            if continued_trajectory_id:
+                kwargs["rerun_all_cells"] = True
+                data_path = (
+                    cfg.DATA_STORAGE_PATH
+                    / "user_trajectories"
+                    / user_id
+                    / continued_trajectory_id
+                )
+                logger.info("Continuing trajectory from %s", continued_trajectory_id)
+                if cfg.PLATFORM_API_KEY is None:
+                    logger.warning(
+                        "Platform API key is not set, can't fetch previous trajectory"
+                    )
+                    previous_research_question = None
+                    previous_final_answer = None
+                else:
+                    logger.info("Fetching previous trajectory")
+                    client = FutureHouseClient(
+                        stage=cfg.CROW_STAGE,
+                        auth_type=AuthType.API_KEY,
+                        api_key=cfg.PLATFORM_API_KEY,
+                    )
+                    previous_trajectory = client.get_task(
+                        continued_trajectory_id, verbose=True
+                    )
+                    previous_research_question = extract_xml_content(
+                        previous_trajectory.query, "query"
+                    )
+                    previous_final_answer = previous_trajectory.environment_frame["state"][
+                        "state"
+                    ]["answer"]
+                    language = previous_trajectory.environment_frame["state"]["info"][
+                        "language"
+                    ]
+                    language = getattr(NBLanguage, language.upper())
+                    kwargs["language"] = language
+
+            elif environment_config.get("gcs_override", False):
+                data_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path  # type: ignore
             else:
                 data_path = (
-                    cfg.DATA_STORAGE_PATH / "user_data" / user_id / gcs_artifact_path
+                    cfg.DATA_STORAGE_PATH / "user_data" / user_id / gcs_artifact_path  # type: ignore
                 )
             logger.info("Trajectory path: %s", trajectory_path)
             logger.info("Data path: %s", data_path)
@@ -230,12 +272,19 @@ def from_task(
                 shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True)
             logger.info("Filtered kwargs: %s", kwargs)
 
-            language = getattr(NBLanguage, environment_config.get("language", "PYTHON"))
-            # Overwrite the language in the kwargs with NBLanguage enum
-            kwargs["language"] = language
-            logger.info("Language: %s", language.name)
+            # If it's continued, we already have the language
+            if continued_trajectory_id:
+                logger.info(
+                    "Language already set from previous trajectory notebook %s",
+                    kwargs.get("language", None),
+                )
+            else:
+                language = getattr(NBLanguage, environment_config.get("language", "PYTHON"))
+                # Overwrite the language in the kwargs with NBLanguage enum
+                kwargs["language"] = language
+                logger.info("Language: %s", language.name)
 
-            if not environment_config.get("eval", False):
+            if not environment_config.get("eval", False) and not continued_trajectory_id:
                 logger.info(
                     "Platform job detected, augmenting user query with CoT instructions"
                 )
@@ -248,6 +297,17 @@ def from_task(
                     f"{task}\n"
                     f"</query>\n"
                 )
+            if continued_trajectory_id and not environment_config.get("eval", False):
+                logger.info(
+                    "Continuation job detected, augmenting user query with continuation instructions"
+                )
+                task = prompts.CONTINUATION_PROMPT_TEMPLATE.format(
+                    previous_research_question=previous_research_question,
+                    previous_final_answer=previous_final_answer,
+                    query=task,
+                    language=kwargs.get("language", "PYTHON"),
+                )
+
         nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
         logger.info("NB path: %s", nb_path)

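Taken together, the new branch lets a job omit gcs_artifact_path when continued_trajectory_id is given: the previous run's trajectory directory becomes the data path, the original question is recovered from its <query> wrapper via extract_xml_content, and the notebook language is inherited from the stored environment frame. A sketch of a plausible continuation call; the IDs are placeholders, and the task/gcs_artifact_path parameter names are assumed from their use in the hunks above (the full signature is not shown in this diff):

# Hypothetical invocation; the trajectory IDs below are placeholders.
env = DataAnalysisEnv.from_task(
    task="Redo the clustering with k=5 and update the final answer",
    gcs_artifact_path=None,  # allowed now that a continuation id is supplied
    user_id="default_user",
    environment_config={"eval": False},  # the continuation path sits inside the config branch
    continued_trajectory_id="prev-trajectory-000",
)
# from_task sets rerun_all_cells=True, fetches the previous run via
# FutureHouseClient.get_task, and wraps the new task in
# prompts.CONTINUATION_PROMPT_TEMPLATE before the episode starts.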
src/fhda/notebook_env.py

Lines changed: 5 additions & 0 deletions
@@ -124,6 +124,7 @@ def __init__(
         language: utils.NBLanguage = utils.NBLanguage.PYTHON,
         allow_download_from_gcs: bool = False,
         run_notebook_on_edit: bool = False,
+        rerun_all_cells: bool = False,
     ):
         """Initialize a notebook environment.
 
@@ -140,6 +141,7 @@
                 task requires data on GCS. Disabled by default.
             run_notebook_on_edit: If True (default), the whole notebook will be rerun
                 after each edit. If False, only the cell that was edited will be rerun.
+            rerun_all_cells: If True, the whole notebook will be run at the beginning of the episode. This is for continued trajectories.
         """
         self.work_dir = Path(work_dir)
         self.nb_path = Path(nb_path) if nb_path else self.work_dir / self.NOTEBOOK_NAME
@@ -149,6 +151,7 @@
         self.allow_download_from_gcs = allow_download_from_gcs
         self.use_docker = cfg.USE_DOCKER
         self.run_notebook_on_edit = run_notebook_on_edit
+        self.rerun_all_cells = rerun_all_cells
 
     async def reset(self) -> tuple[Messages, list[Tool]]:
         nb_path, work_dir = self._set_work_dir()
@@ -158,6 +161,8 @@ async def reset(self) -> tuple[Messages, list[Tool]]:
             language=self.language,
             use_docker=self.use_docker,
         )
+        if self.rerun_all_cells:
+            await self.run_notebook()
 
         self.tools = [
            Tool.from_function(self.edit_cell),
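A minimal sketch of how the new flag is meant to be driven; from_task sets it through kwargs["rerun_all_cells"] = True, and the constructor arguments other than rerun_all_cells below are placeholders:

# Hypothetical usage inside an async context; work_dir is a placeholder path.
env = NBEnvironment(
    work_dir="/tmp/user_trajectories/default_user/prev-trajectory-000",
    rerun_all_cells=True,  # replay the inherited notebook once at reset()
)
messages, tools = await env.reset()
# reset() now executes every existing cell before exposing the editing tools,
# so a continued notebook's variables and outputs are live when the agent acts.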

src/fhda/prompts.py

Lines changed: 25 additions & 0 deletions
@@ -238,3 +238,28 @@
 {GENERAL_NOTEBOOK_GUIDELINES}
 {R_SPECIFIC_GUIDELINES}
 """
+
+CONTINUATION_PROMPT_TEMPLATE = f"""
+{GENERAL_NOTEBOOK_GUIDELINES}
+
+You have been provided with a notebook previously generated by an agent based on a user's research question.
+
+This was the user's research question:
+<previous_research_question>
+{{previous_research_question}}
+</previous_research_question>
+
+This was the final answer generated by the previous agent:
+<previous_final_answer>
+{{previous_final_answer}}
+</previous_final_answer>
+
+The user has now tasked you with addressing a new query:
+<query>
+{{query}}
+</query>
+
+Please make any edits required to the notebook and the answer to address the new query. Be extremely diligent and ensure that the notebook is fully updated to address the new query.
+Note you may have to run all cells one by one again if the user query involved updating one of the intermediate cells and subsequent cells depend on it.
+Once you have updated the notebook, use the submit_answer tool to submit your final answer once the user's query is addressed.
+"""

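Two details about this template are worth noting. It is an f-string, so {GENERAL_NOTEBOOK_GUIDELINES} is interpolated when the module is imported, while the doubled braces survive as single-brace placeholders for the later .format() call. And str.format ignores surplus keyword arguments, so the language=... argument passed in data_analysis_env.py is currently a no-op unless a {language} placeholder is added to the template. A small sketch of both behaviors:

# An f-string turns doubled braces into literal single braces at definition time.
template = f"answering <query>{{query}}</query>"  # -> "answering <query>{query}</query>"

# .format fills the surviving placeholder and silently ignores extra keywords.
print(template.format(query="new question", language="PYTHON"))
# -> answering <query>new question</query>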