Skip to content

Commit 2afef70

Browse files
peteryang1XianBW
andauthored
feat: add user interaction in data science scenario (#1251)
* feat: add interactor classes and user interaction handling for experiments * update code * use fragment retry mechanism instead of rerun() * fix a bug * integrate user instructions into proposal and coder * fix CI * fix CI * feat: add approval option for user instructions submission * feat: enhance user instructions handling in Task and DSExperiment classes * fix CI * add user instructions into hypothesis rewrite * add interface to command line --------- Co-authored-by: Bowen Xian <[email protected]>
1 parent dbe04ea commit 2afef70

File tree

13 files changed

+448
-12
lines changed

13 files changed

+448
-12
lines changed

rdagent/app/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def server_ui(port=19899):
6161
subprocess.run(["python", "rdagent/log/server/app.py", f"--port={port}"])
6262

6363

64+
def ds_user_interact(port=19900):
    """
    Start the Streamlit web app for user interaction in the data science scenario.

    The app (`rdagent/log/ui/ds_user_interact.py`) lists pending interaction sessions and
    lets the user review them and submit feedback. This call blocks until the Streamlit
    process exits.

    :param port: TCP port passed to Streamlit as `--server.port`.
    """
    # NOTE: the script path is relative, so it is resolved against the current working directory.
    commands = ["streamlit", "run", "rdagent/log/ui/ds_user_interact.py", f"--server.port={port}"]
    subprocess.run(commands)
70+
71+
6472
app.command(name="fin_factor")(fin_factor)
6573
app.command(name="fin_model")(fin_model)
6674
app.command(name="fin_quant")(fin_quant)
@@ -72,6 +80,7 @@ def server_ui(port=19899):
7280
app.command(name="server_ui")(server_ui)
7381
app.command(name="health_check")(health_check)
7482
app.command(name="collect_info")(collect_info)
83+
app.command(name="ds_user_interact")(ds_user_interact)
7584

7685

7786
if __name__ == "__main__":

rdagent/app/data_science/conf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pathlib import Path
12
from typing import Literal
23

34
from pydantic_settings import SettingsConfigDict
@@ -20,6 +21,7 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
2021

2122
planner: str = "rdagent.scenarios.data_science.proposal.exp_gen.planner.DSExpPlannerHandCraft"
2223
hypothesis_gen: str = "rdagent.scenarios.data_science.proposal.exp_gen.router.ParallelMultiTraceExpGen"
24+
interactor: str = "rdagent.components.interactor.SkipInteractor"
2325
trace_scheduler: str = "rdagent.scenarios.data_science.proposal.exp_gen.trace_scheduler.RoundRobinScheduler"
2426
"""Hypothesis generation class"""
2527

@@ -182,6 +184,9 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
182184

183185
ensemble_time_upper_bound: bool = False
184186

187+
user_interaction_wait_seconds: int = 6000 # seconds to wait for user interaction
188+
user_interaction_mid_folder: Path = Path.cwd() / "git_ignore_folder" / "RD-Agent_user_interaction"
189+
185190

186191
DS_RD_SETTING = DataScienceBasePropSetting()
187192

rdagent/components/coder/data_science/pipeline/prompts.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ pipeline_coder:
169169
8. **Try-except blocks are ONLY allowed when reading files. If no files are successfully read, it indicates incorrect file paths or reading methods, not a try-except issue. Try-except is PROHIBITED elsewhere in the code. Assert statements are PROHIBITED throughout the entire code.**
170170
9. ATTENTION: ALWAYS use the best saved model (not necessarily final epoch) for predictions. **NEVER create dummy/placeholder submissions (e.g., all 1s, random values)**. If training fails, report failure honestly rather than generating fake submission files.
171171
10. You should ALWAYS generate the complete code rather than partial code.
172-
11. Strictly follow all specifications and general guidelines described above.
172+
11. If the task contains any user instructions, you must strictly follow them. User instructions have the highest priority and should be followed even if they conflict with other specifications or guidelines.
173+
12. Strictly follow all specifications and general guidelines described above.
173174
174175
### Output Format
175176
{% if out_spec %}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from rdagent.core.experiment import ASpecificExp
2+
from rdagent.core.interactor import Interactor
3+
from rdagent.core.proposal import Trace
4+
5+
6+
class SkipInteractor(Interactor[ASpecificExp]):
    """No-op interactor: it never asks for input and hands the experiment back untouched."""

    def interact(self, exp: ASpecificExp, trace: "Trace | None" = None) -> ASpecificExp:
        """
        Return the experiment unchanged.

        :param exp: the experiment that would otherwise be presented for feedback.
        :param trace: accepted only to match `Interactor.interact`; it is not used.
            It is optional, like in the base class, so callers may omit it.
        :return: the very same `exp` object.
        """
        return exp

rdagent/components/workflow/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class BasePropSetting(ExtendedBaseSettings):
1111
knowledge_base: str = ""
1212
knowledge_base_path: str = ""
1313
hypothesis_gen: str = ""
14+
interactor: str = ""
1415
hypothesis2experiment: str = ""
1516
coder: str = ""
1617
runner: str = ""

rdagent/core/experiment.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from copy import deepcopy
1414
from dataclasses import dataclass
1515
from pathlib import Path
16-
from typing import TYPE_CHECKING, Any, Generic, TypeVar
16+
from typing import TYPE_CHECKING, Any, Generic, List, TypeVar
1717

1818
from rdagent.core.conf import RD_AGENT_SETTINGS
1919
from rdagent.core.evaluation import Feedback
@@ -48,13 +48,28 @@ def get_task_information(self) -> str:
4848
"""
4949

5050

51+
class UserInstructions(List[str]):
    """A list of free-text user instructions that renders itself as a prompt section."""

    def __str__(self) -> str:
        """
        Format the instructions as a bulleted block headed by a priority banner.

        Returns an empty string when there are no instructions, so the result can be
        appended to other text unconditionally.
        """
        if not self:
            return ""
        return "\nUser Instructions (Top priority!):\n" + "\n".join(f"- {ui}" for ui in self)
57+
58+
5159
class Task(AbsTask):
52-
def __init__(self, name: str, version: int = 1, description: str = "") -> None:
60+
def __init__(
61+
self,
62+
name: str,
63+
version: int = 1,
64+
description: str = "",
65+
user_instructions: UserInstructions | None = None,
66+
) -> None:
5367
super().__init__(name, version)
5468
self.description = description
69+
self.user_instructions = user_instructions
5570

5671
def get_task_information(self) -> str:
    """
    Build the textual description of this task.

    The result contains the task name and description, followed by the rendered user
    instructions when there are any.
    """
    # `user_instructions` defaults to None; formatting None directly would append the
    # literal text "None" to the description. Objects restored from data saved before the
    # attribute existed may lack it entirely, hence the getattr.
    instructions = getattr(self, "user_instructions", None)
    suffix = str(instructions) if instructions else ""
    return f"Task Name: {self.name}\nDescription: {self.description}{suffix}"
5873

5974
def __repr__(self) -> str:
6075
return f"<{self.__class__.__name__} {self.name}>"
@@ -410,6 +425,21 @@ def __init__(
410425
self.plan: ExperimentPlan | None = (
411426
None # To store the planning information for this experiment, should be generated inside exp_gen.gen
412427
)
428+
self.user_instructions: UserInstructions | None = None # To store the user instructions for this experiment
429+
430+
def set_user_instructions(self, user_instructions: UserInstructions | None) -> None:
    """
    Attach user instructions to this experiment and to every task reachable from it.

    A plain list is wrapped into `UserInstructions` first. Passing None is a no-op, so
    instructions that were set earlier are kept.

    :param user_instructions: the instructions to attach, or None to leave things as they are.
    """
    if user_instructions is None:
        return
    if isinstance(user_instructions, list) and not isinstance(user_instructions, UserInstructions):
        user_instructions = UserInstructions(user_instructions)
    self.user_instructions = user_instructions

    for ws in self.sub_workspace_list:
        # A slot may be None, and a workspace may have no target task; skip both cases
        # instead of failing on the attribute assignment.
        if ws is not None and ws.target_task is not None:
            ws.target_task.user_instructions = user_instructions
    for task in self.sub_tasks:
        task.user_instructions = user_instructions
    exp_ws = self.experiment_workspace
    if exp_ws is not None and exp_ws.target_task is not None:
        exp_ws.target_task.user_instructions = user_instructions
413443

414444
@property
415445
def result(self) -> object:

rdagent/core/interactor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic

from rdagent.core.experiment import ASpecificExp

if TYPE_CHECKING:
    # Needed for annotations only. `Trace` is the proposal trace, not the standard
    # library's `trace.Trace` tracing utility.
    from rdagent.core.proposal import Trace
    from rdagent.core.scenario import Scenario


class Interactor(ABC, Generic[ASpecificExp]):
    """Base class for components that may rewrite an experiment based on external feedback."""

    def __init__(self, scen: Scenario) -> None:
        """
        :param scen: the scenario this interactor works in; stored as `self.scen`.
        """
        self.scen: Scenario = scen

    @abstractmethod
    def interact(self, exp: ASpecificExp, trace: Trace | None = None) -> ASpecificExp:
        """
        Interact with the experiment to get feedback or confirmation.

        Responsibilities:
        - Present the current state of the experiment.
        - Collect input to guide the next steps in the experiment.
        - Rewrite the experiment based on feedback.

        :param exp: the experiment to present.
        :param trace: optional trace available to the interactor.
        :return: the experiment to continue with (possibly rewritten).
        """

rdagent/log/ui/ds_user_interact.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import json
2+
import pickle
3+
import time
4+
from datetime import datetime, timedelta
5+
from pathlib import Path
6+
7+
import streamlit as st
8+
from streamlit import session_state as state
9+
10+
from rdagent.app.data_science.conf import DS_RD_SETTING
11+
12+
st.set_page_config(layout="wide", page_title="RD-Agent_user_interact", page_icon="🎓", initial_sidebar_state="expanded")
13+
14+
# 初始化session state
15+
if "sessions" not in state:
16+
state.sessions = {}
17+
if "selected_session_name" not in state:
18+
state.selected_session_name = None
19+
20+
21+
def render_main_content():
    """
    Render the main panel for the currently selected session.

    Shows the session context (scenario, former attempts, current SOTA code, hypothesis
    candidates) and a form for the user's decision. On submit, the decision is written to
    `<session>_RET.json` in the interaction folder and the session's `.pkl` file is removed.
    When no active session is selected, a warning is shown instead.
    """
    if state.selected_session_name is not None and state.selected_session_name in state.sessions:
        selected_session_data = state.sessions[state.selected_session_name]
        if selected_session_data is not None:
            mid_folder = Path(DS_RD_SETTING.user_interaction_mid_folder)
            st.title(
                f"Session: {state.selected_session_name[:4]} with competition {selected_session_data['competition']}"
            )
            st.title("Contextual Information:")
            st.subheader("Competition scenario:", divider=True)
            st.code(selected_session_data["scenario_description"], language="yaml")
            st.subheader("Former attempts summary:", divider=True)
            st.code(selected_session_data["ds_trace_desc"], language="yaml")
            if selected_session_data["current_code"] != "":
                st.subheader("Current SOTA code", divider=True)
                st.code(
                    body=selected_session_data["current_code"],
                    language="python",
                )

            st.subheader("Hypothesis candidates:", divider=True)
            hypothesis_candidates = selected_session_data["hypothesis_candidates"]
            # An index of -1 marks every candidate tab as picked.
            target_index = selected_session_data["target_hypothesis_index"]
            tabs = st.tabs(
                [
                    f"{'✅' if i == target_index or target_index == -1 else ''}Hypothesis {i+1}"
                    for i in range(len(hypothesis_candidates))
                ]
            )
            for index, hypothesis in enumerate(hypothesis_candidates):
                with tabs[index]:
                    st.code(str(hypothesis), language="yaml")
            st.text("✅ means picked as target hypothesis")

            st.title("Decisions to make:")

            with st.form(key="user_form"):
                st.caption("Please modify the fields below and submit to provide your feedback.")
                original_hypothesis = selected_session_data["target_hypothesis"].hypothesis
                target_hypothesis = st.text_area(
                    "Target hypothesis: (you can copy from candidates)",
                    value=original_hypothesis,
                    height="content",
                )
                original_task_desc = selected_session_data["task"].description
                target_task = st.text_area(
                    "Target task description:",
                    value=original_task_desc,
                    height="content",
                )
                # NOTE(review): the comparison baseline is read from "user_instruction" while
                # the editable entries come from "former_user_instructions" — confirm that the
                # two different keys are intended.
                original_user_instruction = selected_session_data.get("user_instruction")
                user_instruction_list = []
                former_user_instructions = selected_session_data.get("former_user_instructions")
                if former_user_instructions is not None:
                    st.caption(
                        "Former user instructions, you can modify or delete the content to remove certain instruction."
                    )
                    for user_instruction in former_user_instructions:
                        user_instruction_list.append(
                            st.text_area("Former user instruction", value=user_instruction, height="content")
                        )
                user_instruction_list.append(st.text_area("Add new user instruction", value="", height="content"))
                submit = st.form_submit_button("Submit")
                approve = st.form_submit_button("Approve without changes")

                if submit or approve:
                    if approve:
                        submit_dict = {
                            "action": "confirm",
                        }
                    else:
                        # Blank entries are dropped; "no instructions at all" is encoded as None.
                        user_instruction_str_list = [ui for ui in user_instruction_list if ui.strip() != ""]
                        if len(user_instruction_str_list) == 0:
                            user_instruction_str_list = None
                        unchanged = (
                            target_hypothesis == original_hypothesis
                            and target_task == original_task_desc
                            and user_instruction_str_list == original_user_instruction
                        )
                        submit_dict = {
                            "target_hypothesis": target_hypothesis,
                            "task_description": target_task,
                            "user_instruction": user_instruction_str_list,
                            "action": "confirm" if unchanged else "rewrite",
                        }
                    # Close the file explicitly so the answer is fully flushed before
                    # the session file is removed.
                    with open(mid_folder / f"{state.selected_session_name}_RET.json", "w") as ret_file:
                        json.dump(submit_dict, ret_file)
                    (mid_folder / f"{state.selected_session_name}.pkl").unlink(missing_ok=True)
                    st.success("Your feedback has been submitted. Thank you!")
                    time.sleep(5)
                    state.selected_session_name = None

            if st.button("Extend expiration by 60s"):
                session_file = mid_folder / f"{state.selected_session_name}.pkl"
                # SECURITY NOTE: pickle.load can execute arbitrary code. This assumes the
                # interaction folder is only written by trusted processes — TODO confirm.
                with open(session_file, "rb") as f:
                    session_data = pickle.load(f)
                session_data["expired_datetime"] = session_data["expired_datetime"] + timedelta(seconds=60)
                with open(session_file, "wb") as f:
                    pickle.dump(session_data, f)
    else:
        st.warning("Please select a session from the sidebar.")
129+
130+
131+
# Refresh the sessions once per second
132+
@st.fragment(run_every=1)
def update_sessions():
    """
    Reload the active sessions from the interaction folder, then render the main panel.

    Every `*.pkl` file in the folder is one session. Sessions whose `expired_datetime`
    has passed are deleted together with their `<session>_RET.json` file.
    """
    log_folder = Path(DS_RD_SETTING.user_interaction_mid_folder)
    state.sessions = {}
    for session_file in log_folder.glob("*.pkl"):
        try:
            # SECURITY NOTE: pickle.load can execute arbitrary code. This assumes the
            # interaction folder is only written by trusted processes — TODO confirm.
            with open(session_file, "rb") as f:
                session_data = pickle.load(f)
            if session_data["expired_datetime"] > datetime.now():
                state.sessions[session_file.stem] = session_data
            else:
                session_file.unlink(missing_ok=True)
                ret_file = log_folder / f"{session_file.stem}_RET.json"
                ret_file.unlink(missing_ok=True)
        except Exception:
            # Deliberate best effort: a file that cannot be loaded (for example one that
            # is still being written or was just removed) is skipped on this pass.
            continue
    render_main_content()
148+
149+
150+
@st.fragment(run_every=1)
def render_sidebar():
    """
    Render the sidebar: one bordered entry per active session with a select button
    and the remaining time until the session expires.
    """
    st.title("R&D-Agent User Interaction Portal")
    if not state.sessions:
        st.warning("No active sessions available. Please wait.")
        return

    st.header("Active Sessions")
    st.caption("Click a session to view:")
    # Iterate over a snapshot of the names rather than the live mapping.
    for session_name in list(state.sessions):
        with st.container(border=True):
            time_left = state.sessions[session_name]["expired_datetime"] - datetime.now()
            seconds_left = int(time_left.total_seconds())
            if seconds_left > 0:
                label = f"{seconds_left}s to expire"
            else:
                label = "Expired"
            clicked = st.button(f"session id:{session_name[:4]}", key=f"session_btn_{session_name}")
            if clicked:
                state.selected_session_name = session_name
                state.data = state.sessions[session_name]
            st.markdown(f"⏳ {label}")
168+
169+
170+
update_sessions()
171+
with st.sidebar:
172+
render_sidebar()

rdagent/scenarios/data_science/experiment/experiment.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
import pandas as pd
55

6-
from rdagent.core.experiment import Experiment, FBWorkspace, Task
6+
from rdagent.core.experiment import Experiment, FBWorkspace, Task, UserInstructions
77

88
COMPONENT = Literal["DataLoadSpec", "FeatureEng", "Model", "Ensemble", "Workflow", "Pipeline"]
99

1010

1111
class DSExperiment(Experiment[Task, FBWorkspace, FBWorkspace]):
12-
def __init__(self, pending_tasks_list: list, *args, **kwargs) -> None:
12+
def __init__(self, pending_tasks_list: list, hypothesis_candidates: list | None = None, *args, **kwargs) -> None:
1313
super().__init__(sub_tasks=[], *args, **kwargs)
1414
# Status
1515
# - Initial: blank;
@@ -18,11 +18,20 @@ def __init__(self, pending_tasks_list: list, *args, **kwargs) -> None:
1818
# the initial workspace or the successful new version after coding
1919
self.experiment_workspace = FBWorkspace()
2020
self.pending_tasks_list = pending_tasks_list
21+
self.hypothesis_candidates = hypothesis_candidates
2122

2223
self.format_check_result = None
2324
# this field is optional. It is not none only when we have a format checker. Currently, only following cases are supported.
2425
# - mle-bench
2526

27+
def set_user_instructions(self, user_instructions: UserInstructions | None):
    """
    Attach user instructions to this experiment, including its pending tasks.

    The parent implementation is called first; afterwards every task in
    `pending_tasks_list` receives the same instructions. Passing None is a no-op.

    :param user_instructions: the instructions to attach, or None to leave things as they are.
    """
    super().set_user_instructions(user_instructions)
    if user_instructions is None:
        return
    # The parent call stores the (possibly normalized) instructions on the experiment.
    # Reuse that stored object so pending tasks get exactly what the other tasks got,
    # not the raw argument.
    stored_instructions = self.user_instructions
    for task_list in self.pending_tasks_list:
        for task in task_list:
            task.user_instructions = stored_instructions
34+
2635
def is_ready_to_run(self) -> bool:
2736
"""
2837
ready to run does not indicate the experiment is runnable

0 commit comments

Comments
 (0)