
Commit dbe04ea

chore: valid selector update (#1248)
* chore: add medal info
* return sota_exp_stat
* update hit check
* update experiment
* update experiment
* extract log
* remove old log folder
* update candidates
* keep highest score
* early stop if no medal candidate
1 parent 4dfb8a1

File tree

1 file changed: +94 -24 lines changed

  • rdagent/scenarios/data_science/proposal/exp_gen/select

rdagent/scenarios/data_science/proposal/exp_gen/select/submit.py

Lines changed: 94 additions & 24 deletions
@@ -3,13 +3,15 @@
 import pickle
 import re
 import shutil
+import tarfile
 import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 import fire
 import numpy as np
 import pandas as pd
+import yaml
 from loguru import logger

 from rdagent.app.data_science.conf import DS_RD_SETTING
@@ -204,18 +206,43 @@ def get_sort_key(exp_fb: Tuple[DSExperiment, ExperimentFeedback]) -> Tuple[bool,
             # Sort key prioritizes decision (True > False), then score
             return (feedback.decision, score) if self.use_decision else score

+        def get_sort_key_without_decision(exp_fb: Tuple[DSExperiment, ExperimentFeedback]) -> Tuple[bool, float]:
+            exp, feedback = exp_fb
+            score = -np.inf
+            if exp.result is not None:
+                try:
+                    score = pd.DataFrame(exp.result).loc["ensemble"].iloc[0]
+                    if isinstance(score, str):
+                        score = float(score.strip("tensor()"))
+                    score = direction_sign * score
+                except:
+                    logger.warning(f"Failed to extract score from result {exp.result}")
+
+            return score
+
         # Collect candidates
         if self.each_trace:
-            candidate_list = []
-            leaves = trace.get_leaves()
-            num_per_leaf = max(self.num_candidates // len(leaves), 1)
-            for leaf in leaves:
-                branch_experiments = trace.experiment_and_feedback_list_after_init(
-                    return_type="all", search_type="ancestors", selection=(leaf,)
-                )
-                if branch_experiments:
-                    branch_experiments.sort(key=get_sort_key, reverse=True)
-                    candidate_list.extend(branch_experiments[:num_per_leaf])
+            # Add best experiment without decision
+            hist = trace.hist.copy()
+            hist.sort(key=get_sort_key_without_decision, reverse=True)
+            candidate_list = [hist[0]]
+
+            root_to_experiments = {}
+            for node in range(len(trace.hist)):
+                parents = trace.get_parents(node)
+                if parents:
+                    root = parents[0]
+                    if root not in root_to_experiments:
+                        root_to_experiments[root] = []
+                    root_to_experiments[root].append(trace.hist[node])
+
+            # Select top-k from each branch
+            num_per_leaf = max(self.num_candidates // len(root_to_experiments), 2)
+            for root, exps in root_to_experiments.items():
+                if not exps:
+                    continue
+                exps.sort(key=get_sort_key, reverse=True)
+                candidate_list.extend(exps[:num_per_leaf])
             # Remove duplicates
             candidate_list = list(set(candidate_list))
         else:
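Note on the new sort key: get_sort_key_without_decision ranks every entry in trace.hist purely by its ensemble score, coping with scores serialized as "tensor(...)" strings and with lower-is-better metrics via direction_sign. Below is a minimal standalone sketch of that normalization; normalize_score is a hypothetical helper for illustration only, not part of the repo.

    import numpy as np

    def normalize_score(raw, direction_sign: int = 1) -> float:
        """Coerce a raw ensemble score (possibly a "tensor(...)" string) into a comparable float."""
        if raw is None:
            return -np.inf  # missing results sort last, as in the selector above
        try:
            if isinstance(raw, str):
                raw = float(raw.strip("tensor()"))  # "tensor(0.8123)" -> 0.8123
            return direction_sign * float(raw)
        except (TypeError, ValueError):
            return -np.inf

    print(normalize_score("tensor(0.8123)"))         # 0.8123
    print(normalize_score(0.71, direction_sign=-1))  # -0.71 for lower-is-better metrics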
@@ -226,7 +253,7 @@ def get_sort_key(exp_fb: Tuple[DSExperiment, ExperimentFeedback]) -> Tuple[bool,
             return None

         # Sort and select the top N
-        candidate_list.sort(key=get_sort_key, reverse=True)
+        candidate_list.sort(key=get_sort_key_without_decision, reverse=True)

         top_experiments = [exp for exp, _ in candidate_list[: self.num_candidates]]
         logger.info(f"BestValidSelector: Selected {len(top_experiments)} experiments.")
@@ -426,6 +453,8 @@ def _generate_and_run_script(
     ws = FBWorkspace()
     ws.inject_code_from_file_dict(reference_exp.experiment_workspace)
     ws.inject_files(**{f"{script_type}.py": generated_code})
+    reference_code = reference_exp.experiment_workspace.file_dict.get("main.py", "")
+    ws.inject_files(**{"reference_code.py": reference_code})

     if script_type == "data":
         # For data.py, we need the original data to sample from
@@ -457,7 +486,7 @@ def _generate_and_run_script(
         extra_volumes={str(Path(mock_folder) / input_folder): {"bind": input_folder, "mode": "rw"}},
         running_timeout_period=DS_RD_SETTING.full_timeout,
     )
-    result = ws.run(env=env, entry=f"python main.py --cache-buster={time.time()}")
+    result = ws.run(env=env, entry=f"python reference_code.py")
     stdout = re.sub(r"^chmod:.*\n?", "", result.get_truncated_stdout(), flags=re.MULTILINE)
     if result.exit_code == 0:
         # move submission.csv to mock_folder
@@ -579,8 +608,11 @@ def check_hit(selected_exp: DSExperiment, trace: Trace, sota_result: Dict[str, A
     # Check by loop_id if available
     if hasattr(trace, "idx2loop_id"):
         loop_id = trace.idx2loop_id.get(index)
-        if loop_id and loop_id in sota_result.get("medal_loops", []):
-            return True
+        if loop_id:
+            if loop_id in sota_result.get("medal_loops", []):
+                return True
+            return False
+
     # Fallback to checking by index
     if index in sota_result.get("medal_loops_index", []):
         return True
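The reworked check changes the precedence: once a loop_id is resolved, membership in medal_loops is the final answer, and the index-based fallback only applies when no loop_id mapping exists. A standalone sketch of that ordering (is_medal_hit is a hypothetical helper, not the repo's check_hit):

    from typing import Any, Dict, Optional

    def is_medal_hit(index: int, loop_id: Optional[int], sota_result: Dict[str, Any]) -> bool:
        if loop_id:
            # loop_id resolved: medal_loops alone decides, no fallback
            return loop_id in sota_result.get("medal_loops", [])
        # no loop_id mapping: fall back to the raw trace index
        return index in sota_result.get("medal_loops_index", [])

    sota = {"medal_loops": [3, 7], "medal_loops_index": [12]}
    print(is_medal_hit(index=12, loop_id=7, sota_result=sota))     # True  (loop_id match)
    print(is_medal_hit(index=12, loop_id=5, sota_result=sota))     # False (no index fallback)
    print(is_medal_hit(index=12, loop_id=None, sota_result=sota))  # True  (index fallback)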
@@ -594,6 +626,11 @@ def try_get_loop_id(trace: Trace, exp: DSExperiment):
     return index


+def extract_tar(tar_path: str, to_dir: str = "log") -> str:
+    with tarfile.open(tar_path, mode="r:*") as tar:
+        tar.extractall(path=to_dir)
+
+
 # ==============================================================================
 # ## Main Orchestration Logic
 # ==============================================================================
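extract_tar relies on tarfile's "r:*" mode to auto-detect compression and unpacks into ./log by default; despite the -> str annotation it currently returns None, and extractall is called without a member filter, so the archive is assumed to be trusted pipeline output. A usage sketch, assuming extract_tar from this module is in scope and a hypothetical archive path:

    from pathlib import Path

    archive = Path("/mnt/output/some_job/results.tar.gz")  # hypothetical path and name
    if archive.exists():
        extract_tar(str(archive), to_dir="log")  # unpacks the whole archive under ./log
        print(sorted(p.name for p in Path("log").iterdir()))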
@@ -609,13 +646,14 @@ def evaluate_one_trace(
     experiment: str = "validation",
     log_path: Path | None = None,
     sample_rate: float = 0.8,
-) -> Tuple[str, bool]:
+) -> Tuple[str, bool, str]:
     """
     Loads a single trace, uses the specified selector to pick an experiment,
     and checks if the selection was a "hit" (a known SOTA solution).
     """
     competition = trace.scen.competition
     hit = False
+    sota_exp_stat = ""

     # Example of scenario-specific adjustment
     if competition == "detecting-insults-in-social-commentary":
@@ -635,7 +673,7 @@ def evaluate_one_trace(
     if selector_name == "validation":
         if not Path(f"{DS_RD_SETTING.local_data_path}/{competition}").exists():
             logger.warning(f"Competition {DS_RD_SETTING.local_data_path}/{competition} does not exist, skipping.")
-            return competition, False
+            return competition, False, sota_exp_stat
         # The ValidationSelector is used to select the best re-test score.
         quick_selector = BestValidSelector(num_candidates=1, use_decision=True, each_trace=False)
         quick_selected_exps = quick_selector.get_sota_exp_to_submit(trace)
@@ -647,14 +685,25 @@ def evaluate_one_trace(
         candidate_exps = base_selector.collect_sota_candidates(trace)
         if not candidate_exps:
             logger.info("ValidationSelector: Base selector returned no candidates.")
-            return competition, False
+            return competition, False, sota_exp_stat

         logger.info(f"ValidationSelector: Received {len(candidate_exps)} candidates for validation.")
+        pool_hit = False
         if debug:
             pool_hit = any(check_hit(candidate_exp, trace, sota_result) for candidate_exp in candidate_exps)
-            if not pool_hit:
-                logger.info("ValidationSelector: Base selector's candidates did not hit any SOTA. Skipping validation.")
-                return competition, False
+        else:
+            for exp in candidate_exps:
+                loop_id = try_get_loop_id(trace, exp)
+                sota_mle_score_paths = [i for i in log_path.rglob(f"Loop_{loop_id}/running/mle_score/**/*.pkl")]
+                if len(sota_mle_score_paths):
+                    with sota_mle_score_paths[0].open("rb") as f:
+                        sota_mle_score = extract_json(pickle.load(f))
+                    if sota_mle_score.get("any_medal", False):
+                        pool_hit = True
+                        break
+        if not pool_hit:
+            logger.info("ValidationSelector: Selector's candidates did not hit any medal. Skipping validation.")
+            return competition, False, sota_exp_stat

         selector = ValidationSelector(
             candidate=[(exp, try_get_loop_id(trace, exp)) for exp in candidate_exps],
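When not in debug mode, the pool check now goes straight to the logged MLE scores: for each candidate's loop it looks for a pickle under Loop_<loop_id>/running/mle_score/ inside log_path and stops at the first one reporting any_medal. Below is a standalone sketch of that lookup under the same assumed layout; the plain dict/JSON handling stands in for the repo's extract_json helper, and loop_id 12 is illustrative.

    import json
    import pickle
    from pathlib import Path
    from typing import Any, Dict, Optional

    def load_first_mle_score(log_path: Path, loop_id: int) -> Optional[Dict[str, Any]]:
        for pkl in log_path.rglob(f"Loop_{loop_id}/running/mle_score/**/*.pkl"):
            with pkl.open("rb") as f:
                obj = pickle.load(f)
            # the stored object may already be a dict, or a JSON string to parse
            return obj if isinstance(obj, dict) else json.loads(obj)
        return None

    score = load_first_mle_score(Path("log"), loop_id=12)
    if score is not None and score.get("any_medal", False):
        print("candidate pool contains at least one medal run")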
@@ -682,7 +731,14 @@ def evaluate_one_trace(
         with sota_mle_score_paths[0].open("rb") as f:
             sota_mle_score = extract_json(pickle.load(f))
         hit = sota_mle_score.get("any_medal", False)
-    return competition, hit
+    if hit:
+        if sota_mle_score["gold_medal"]:
+            sota_exp_stat = "gold"
+        elif sota_mle_score["silver_medal"]:
+            sota_exp_stat = "silver"
+        elif sota_mle_score["bronze_medal"]:
+            sota_exp_stat = "bronze"
+    return competition, hit, sota_exp_stat


 def select_on_existing_trace(
@@ -712,11 +768,21 @@ def select_on_existing_trace(

     # Prepare list of tasks for multiprocessing
     tasks = []
+    if debug and experiment and "yaml" in trace_root:
+        job_info = yaml.safe_load(open(str(Path(trace_root) / f"{experiment}.yaml"), "r"))
+        if not competition:
+            competition = os.getenv("DS_COMPETITION")
+        for job in job_info:
+            if job["submit_args"]["env"]["DS_COMPETITION"] == competition:
+                tar_file = Path("/mnt/output") / job["results_dir"] / job["submit_args"]["env"]["RD_RES_NAME"]
+                extract_tar(tar_file)
+        debug = False
+
     if debug:
         for trace_folder in trace_root_path.iterdir():
             if not trace_folder.is_dir():
                 continue
-            if experiment is not None:
+            if isinstance(experiment, str) and experiment:
                 if trace_folder.name not in experiment:
                     continue
             for trace_pkl_path in trace_folder.glob("*.pkl"):
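The new yaml branch assumes <trace_root>/<experiment>.yaml deserializes to a list of job records, each carrying results_dir and a submit_args.env block with DS_COMPETITION and RD_RES_NAME; the matching job's archive under /mnt/output is extracted and debug is switched off so the non-debug path runs against the unpacked log. A minimal sketch of the record shape the lookup relies on (field names come from this diff, values are made up):

    # Shape of one job record as the lookup above expects it (values are illustrative):
    job = {
        "results_dir": "jobs/2024-01-01/run_a",
        "submit_args": {
            "env": {
                "DS_COMPETITION": "detecting-insults-in-social-commentary",
                "RD_RES_NAME": "results.tar.gz",
            }
        },
    }

    env = job["submit_args"]["env"]
    if env["DS_COMPETITION"] == "detecting-insults-in-social-commentary":
        print(job["results_dir"], env["RD_RES_NAME"])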
@@ -779,13 +845,15 @@ def select_on_existing_trace(
     hit_list = multiprocessing_wrapper(tasks, n=1)  # n=1 for sequential debugging, increase for parallel runs

     # Aggregate and report results
-    hit_count = sum(hit for _, hit in hit_list if hit is not None)
+    hit_count = sum(hit for _, hit, _ in hit_list if hit is not None)
     total_valid_traces = len(hit_list)

     print("\n" + "=" * 50)
     print(f"Evaluation Summary for Selector: '{selector_name}'")
     print(f"Total Traces Processed: {total_valid_traces}")
     print(f"Total Hits: {hit_count}")
+    if not debug and hit_count:
+        print(f"Medal info: {hit_list[0][2]}")
     if total_valid_traces > 0:
         hit_rate = (hit_count / total_valid_traces) * 100
         print(f"Hit Rate: {hit_rate:.2f}%")
@@ -796,11 +864,13 @@ def select_on_existing_trace(
         "total": total_valid_traces,
         "hit_rate": hit_rate if total_valid_traces > 0 else 0,
     }
-    result_dict["details"] = [{comp: hit} for comp, hit in hit_list]
+    result_dict["details"] = [{comp: hit} for comp, hit, _ in hit_list]

     with open(f"result_{selector_name}.json", "w") as f:
         json.dump(result_dict, f, indent=4)
     logger.info(f"Results saved to result_{selector_name}.json")
+    if "yaml" in trace_root and Path("log/log").exists():
+        shutil.rmtree("log/log")


 if __name__ == "__main__":
