|
| 1 | +import json |
| 2 | +import pickle |
| 3 | +import time |
| 4 | +from datetime import datetime, timedelta |
| 5 | +from pathlib import Path |
| 6 | + |
| 7 | +import streamlit as st |
| 8 | +from streamlit import session_state as state |
| 9 | + |
| 10 | +from rdagent.app.data_science.conf import DS_RD_SETTING |
| 11 | + |
| 12 | +st.set_page_config(layout="wide", page_title="RD-Agent_user_interact", page_icon="🎓", initial_sidebar_state="expanded") |
| 13 | + |
| 14 | +# 初始化session state |
| 15 | +if "sessions" not in state: |
| 16 | + state.sessions = {} |
| 17 | +if "selected_session_name" not in state: |
| 18 | + state.selected_session_name = None |
| 19 | + |
| 20 | + |
| 21 | +def render_main_content(): |
| 22 | + """渲染主要内容区域""" |
| 23 | + if state.selected_session_name is not None and state.selected_session_name in state.sessions: |
| 24 | + selected_session_data = state.sessions[state.selected_session_name] |
| 25 | + if selected_session_data is not None: |
| 26 | + st.title( |
| 27 | + f"Session: {state.selected_session_name[:4]} with competition {selected_session_data['competition']}" |
| 28 | + ) |
| 29 | + st.title("Contextual Information:") |
| 30 | + st.subheader("Competition scenario:", divider=True) |
| 31 | + scenario = st.code(selected_session_data["scenario_description"], language="yaml") |
| 32 | + st.subheader("Former attempts summary:", divider=True) |
| 33 | + scenario = st.code(selected_session_data["ds_trace_desc"], language="yaml") |
| 34 | + if selected_session_data["current_code"] != "": |
| 35 | + st.subheader("Current SOTA code", divider=True) |
| 36 | + scenario = st.code( |
| 37 | + body=selected_session_data["current_code"], |
| 38 | + language="python", |
| 39 | + ) |
| 40 | + |
| 41 | + st.subheader("Hypothesis candidates:", divider=True) |
| 42 | + hypothesis_candidates = selected_session_data["hypothesis_candidates"] |
| 43 | + tabs = st.tabs( |
| 44 | + [ |
| 45 | + f"{'✅' if i == selected_session_data['target_hypothesis_index'] or selected_session_data['target_hypothesis_index'] == -1 else ''}Hypothesis {i+1}" |
| 46 | + for i in range(len(hypothesis_candidates)) |
| 47 | + ] |
| 48 | + ) |
| 49 | + for index, hypothesis in enumerate(hypothesis_candidates): |
| 50 | + with tabs[index]: |
| 51 | + st.code(str(hypothesis), language="yaml") |
| 52 | + st.text("✅ means picked as target hypothesis") |
| 53 | + |
| 54 | + st.title("Decisions to make:") |
| 55 | + |
| 56 | + with st.form(key="user_form"): |
| 57 | + st.caption("Please modify the fields below and submit to provide your feedback.") |
| 58 | + target_hypothesis = st.text_area( |
| 59 | + "Target hypothesis: (you can copy from candidates)", |
| 60 | + value=(original_hypothesis := selected_session_data["target_hypothesis"].hypothesis), |
| 61 | + height="content", |
| 62 | + ) |
| 63 | + target_task = st.text_area( |
| 64 | + "Target task description:", |
| 65 | + value=(original_task_desc := selected_session_data["task"].description), |
| 66 | + height="content", |
| 67 | + ) |
| 68 | + original_user_instruction = selected_session_data.get("user_instruction") |
| 69 | + user_instruction_list = [] |
| 70 | + if selected_session_data.get("former_user_instructions") is not None: |
| 71 | + st.caption( |
| 72 | + "Former user instructions, you can modify or delete the content to remove certain instruction." |
| 73 | + ) |
| 74 | + for user_instruction in selected_session_data.get("former_user_instructions"): |
| 75 | + user_instruction_list.append( |
| 76 | + st.text_area("Former user instruction", value=user_instruction, height="content") |
| 77 | + ) |
| 78 | + user_instruction_list.append(st.text_area("Add new user instruction", value="", height="content")) |
| 79 | + submit = st.form_submit_button("Submit") |
| 80 | + approve = st.form_submit_button("Approve without changes") |
| 81 | + |
| 82 | + if submit or approve: |
| 83 | + if approve: |
| 84 | + submit_dict = { |
| 85 | + "action": "confirm", |
| 86 | + } |
| 87 | + else: |
| 88 | + user_instruction_str_list = [ui for ui in user_instruction_list if ui.strip() != ""] |
| 89 | + user_instruction_str_list = ( |
| 90 | + None if len(user_instruction_str_list) == 0 else user_instruction_str_list |
| 91 | + ) |
| 92 | + action = ( |
| 93 | + "confirm" |
| 94 | + if target_hypothesis == original_hypothesis |
| 95 | + and target_task == original_task_desc |
| 96 | + and user_instruction_str_list == original_user_instruction |
| 97 | + else "rewrite" |
| 98 | + ) |
| 99 | + submit_dict = { |
| 100 | + "target_hypothesis": target_hypothesis, |
| 101 | + "task_description": target_task, |
| 102 | + "user_instruction": user_instruction_str_list, |
| 103 | + "action": action, |
| 104 | + } |
| 105 | + json.dump( |
| 106 | + submit_dict, |
| 107 | + open( |
| 108 | + DS_RD_SETTING.user_interaction_mid_folder / f"{state.selected_session_name}_RET.json", "w" |
| 109 | + ), |
| 110 | + ) |
| 111 | + Path(DS_RD_SETTING.user_interaction_mid_folder / f"{state.selected_session_name}.pkl").unlink( |
| 112 | + missing_ok=True |
| 113 | + ) |
| 114 | + st.success("Your feedback has been submitted. Thank you!") |
| 115 | + time.sleep(5) |
| 116 | + state.selected_session_name = None |
| 117 | + |
| 118 | + if st.button("Extend expiration by 60s"): |
| 119 | + session_data = pickle.load( |
| 120 | + open(DS_RD_SETTING.user_interaction_mid_folder / f"{state.selected_session_name}.pkl", "rb") |
| 121 | + ) |
| 122 | + session_data["expired_datetime"] = session_data["expired_datetime"] + timedelta(seconds=60) |
| 123 | + pickle.dump( |
| 124 | + session_data, |
| 125 | + open(DS_RD_SETTING.user_interaction_mid_folder / f"{state.selected_session_name}.pkl", "wb"), |
| 126 | + ) |
| 127 | + else: |
| 128 | + st.warning("Please select a session from the sidebar.") |
| 129 | + |
| 130 | + |
| 131 | +# 每秒更新一次sessions |
| 132 | +@st.fragment(run_every=1) |
| 133 | +def update_sessions(): |
| 134 | + log_folder = Path(DS_RD_SETTING.user_interaction_mid_folder) |
| 135 | + state.sessions = {} |
| 136 | + for session_file in log_folder.glob("*.pkl"): |
| 137 | + try: |
| 138 | + session_data = pickle.load(open(session_file, "rb")) |
| 139 | + if session_data["expired_datetime"] > datetime.now(): |
| 140 | + state.sessions[session_file.stem] = session_data |
| 141 | + else: |
| 142 | + session_file.unlink(missing_ok=True) |
| 143 | + ret_file = log_folder / f"{session_file.stem}_RET.json" |
| 144 | + ret_file.unlink(missing_ok=True) |
| 145 | + except Exception as e: |
| 146 | + continue |
| 147 | + render_main_content() |
| 148 | + |
| 149 | + |
| 150 | +@st.fragment(run_every=1) |
| 151 | +def render_sidebar(): |
| 152 | + st.title("R&D-Agent User Interaction Portal") |
| 153 | + if state.sessions: |
| 154 | + st.header("Active Sessions") |
| 155 | + st.caption("Click a session to view:") |
| 156 | + session_names = [name for name in state.sessions] |
| 157 | + for session_name in session_names: |
| 158 | + with st.container(border=True): |
| 159 | + remaining = state.sessions[session_name]["expired_datetime"] - datetime.now() |
| 160 | + total_sec = int(remaining.total_seconds()) |
| 161 | + label = f"{total_sec}s to expire" if total_sec > 0 else "Expired" |
| 162 | + if st.button(f"session id:{session_name[:4]}", key=f"session_btn_{session_name}"): |
| 163 | + state.selected_session_name = session_name |
| 164 | + state.data = state.sessions[session_name] |
| 165 | + st.markdown(f"⏳ {label}") |
| 166 | + else: |
| 167 | + st.warning("No active sessions available. Please wait.") |
| 168 | + |
| 169 | + |
| 170 | +update_sessions() |
| 171 | +with st.sidebar: |
| 172 | + render_sidebar() |
0 commit comments