33import pickle
44import re
55import shutil
6+ import tarfile
67import time
78from pathlib import Path
89from typing import Any , Dict , List , Optional , Tuple
910
1011import fire
1112import numpy as np
1213import pandas as pd
14+ import yaml
1315from loguru import logger
1416
1517from rdagent .app .data_science .conf import DS_RD_SETTING
@@ -204,18 +206,43 @@ def get_sort_key(exp_fb: Tuple[DSExperiment, ExperimentFeedback]) -> Tuple[bool,
204206 # Sort key prioritizes decision (True > False), then score
205207 return (feedback .decision , score ) if self .use_decision else score
206208
209+ def get_sort_key_without_decision (exp_fb : Tuple [DSExperiment , ExperimentFeedback ]) -> Tuple [bool , float ]:
210+ exp , feedback = exp_fb
211+ score = - np .inf
212+ if exp .result is not None :
213+ try :
214+ score = pd .DataFrame (exp .result ).loc ["ensemble" ].iloc [0 ]
215+ if isinstance (score , str ):
216+ score = float (score .strip ("tensor()" ))
217+ score = direction_sign * score
218+ except :
219+ logger .warning (f"Failed to extract score from result { exp .result } " )
220+
221+ return score
222+
207223 # Collect candidates
208224 if self .each_trace :
209- candidate_list = []
210- leaves = trace .get_leaves ()
211- num_per_leaf = max (self .num_candidates // len (leaves ), 1 )
212- for leaf in leaves :
213- branch_experiments = trace .experiment_and_feedback_list_after_init (
214- return_type = "all" , search_type = "ancestors" , selection = (leaf ,)
215- )
216- if branch_experiments :
217- branch_experiments .sort (key = get_sort_key , reverse = True )
218- candidate_list .extend (branch_experiments [:num_per_leaf ])
225+ # Add best experiment without decision
226+ hist = trace .hist .copy ()
227+ hist .sort (key = get_sort_key_without_decision , reverse = True )
228+ candidate_list = [hist [0 ]]
229+
230+ root_to_experiments = {}
231+ for node in range (len (trace .hist )):
232+ parents = trace .get_parents (node )
233+ if parents :
234+ root = parents [0 ]
235+ if root not in root_to_experiments :
236+ root_to_experiments [root ] = []
237+ root_to_experiments [root ].append (trace .hist [node ])
238+
239+ # Select top-k from each branch
240+ num_per_leaf = max (self .num_candidates // len (root_to_experiments ), 2 )
241+ for root , exps in root_to_experiments .items ():
242+ if not exps :
243+ continue
244+ exps .sort (key = get_sort_key , reverse = True )
245+ candidate_list .extend (exps [:num_per_leaf ])
219246 # Remove duplicates
220247 candidate_list = list (set (candidate_list ))
221248 else :
@@ -226,7 +253,7 @@ def get_sort_key(exp_fb: Tuple[DSExperiment, ExperimentFeedback]) -> Tuple[bool,
226253 return None
227254
228255 # Sort and select the top N
229- candidate_list .sort (key = get_sort_key , reverse = True )
256+ candidate_list .sort (key = get_sort_key_without_decision , reverse = True )
230257
231258 top_experiments = [exp for exp , _ in candidate_list [: self .num_candidates ]]
232259 logger .info (f"BestValidSelector: Selected { len (top_experiments )} experiments." )
@@ -426,6 +453,8 @@ def _generate_and_run_script(
426453 ws = FBWorkspace ()
427454 ws .inject_code_from_file_dict (reference_exp .experiment_workspace )
428455 ws .inject_files (** {f"{ script_type } .py" : generated_code })
456+ reference_code = reference_exp .experiment_workspace .file_dict .get ("main.py" , "" )
457+ ws .inject_files (** {"reference_code.py" : reference_code })
429458
430459 if script_type == "data" :
431460 # For data.py, we need the original data to sample from
@@ -457,7 +486,7 @@ def _generate_and_run_script(
457486 extra_volumes = {str (Path (mock_folder ) / input_folder ): {"bind" : input_folder , "mode" : "rw" }},
458487 running_timeout_period = DS_RD_SETTING .full_timeout ,
459488 )
460- result = ws .run (env = env , entry = f"python main .py --cache-buster= { time . time () } " )
489+ result = ws .run (env = env , entry = f"python reference_code .py" )
461490 stdout = re .sub (r"^chmod:.*\n?" , "" , result .get_truncated_stdout (), flags = re .MULTILINE )
462491 if result .exit_code == 0 :
463492 # move submission.csv to mock_folder
@@ -579,8 +608,11 @@ def check_hit(selected_exp: DSExperiment, trace: Trace, sota_result: Dict[str, A
579608 # Check by loop_id if available
580609 if hasattr (trace , "idx2loop_id" ):
581610 loop_id = trace .idx2loop_id .get (index )
582- if loop_id and loop_id in sota_result .get ("medal_loops" , []):
583- return True
611+ if loop_id :
612+ if loop_id in sota_result .get ("medal_loops" , []):
613+ return True
614+ return False
615+
584616 # Fallback to checking by index
585617 if index in sota_result .get ("medal_loops_index" , []):
586618 return True
@@ -594,6 +626,11 @@ def try_get_loop_id(trace: Trace, exp: DSExperiment):
594626 return index
595627
596628
629+ def extract_tar (tar_path : str , to_dir : str = "log" ) -> str :
630+ with tarfile .open (tar_path , mode = "r:*" ) as tar :
631+ tar .extractall (path = to_dir )
632+
633+
597634# ==============================================================================
598635# ## Main Orchestration Logic
599636# ==============================================================================
@@ -609,13 +646,14 @@ def evaluate_one_trace(
609646 experiment : str = "validation" ,
610647 log_path : Path | None = None ,
611648 sample_rate : float = 0.8 ,
612- ) -> Tuple [str , bool ]:
649+ ) -> Tuple [str , bool , str ]:
613650 """
614651 Loads a single trace, uses the specified selector to pick an experiment,
615652 and checks if the selection was a "hit" (a known SOTA solution).
616653 """
617654 competition = trace .scen .competition
618655 hit = False
656+ sota_exp_stat = ""
619657
620658 # Example of scenario-specific adjustment
621659 if competition == "detecting-insults-in-social-commentary" :
@@ -635,7 +673,7 @@ def evaluate_one_trace(
635673 if selector_name == "validation" :
636674 if not Path (f"{ DS_RD_SETTING .local_data_path } /{ competition } " ).exists ():
637675 logger .warning (f"Competition { DS_RD_SETTING .local_data_path } /{ competition } does not exist, skipping." )
638- return competition , False
676+ return competition , False , sota_exp_stat
639677 # The ValidationSelector is used to select the best re-test score.
640678 quick_selector = BestValidSelector (num_candidates = 1 , use_decision = True , each_trace = False )
641679 quick_selected_exps = quick_selector .get_sota_exp_to_submit (trace )
@@ -647,14 +685,25 @@ def evaluate_one_trace(
647685 candidate_exps = base_selector .collect_sota_candidates (trace )
648686 if not candidate_exps :
649687 logger .info ("ValidationSelector: Base selector returned no candidates." )
650- return competition , False
688+ return competition , False , sota_exp_stat
651689
652690 logger .info (f"ValidationSelector: Received { len (candidate_exps )} candidates for validation." )
691+ pool_hit = False
653692 if debug :
654693 pool_hit = any (check_hit (candidate_exp , trace , sota_result ) for candidate_exp in candidate_exps )
655- if not pool_hit :
656- logger .info ("ValidationSelector: Base selector's candidates did not hit any SOTA. Skipping validation." )
657- return competition , False
694+ else :
695+ for exp in candidate_exps :
696+ loop_id = try_get_loop_id (trace , exp )
697+ sota_mle_score_paths = [i for i in log_path .rglob (f"Loop_{ loop_id } /running/mle_score/**/*.pkl" )]
698+ if len (sota_mle_score_paths ):
699+ with sota_mle_score_paths [0 ].open ("rb" ) as f :
700+ sota_mle_score = extract_json (pickle .load (f ))
701+ if sota_mle_score .get ("any_medal" , False ):
702+ pool_hit = True
703+ break
704+ if not pool_hit :
705+ logger .info ("ValidationSelector: Selector's candidates did not hit any medal. Skipping validation." )
706+ return competition , False , sota_exp_stat
658707
659708 selector = ValidationSelector (
660709 candidate = [(exp , try_get_loop_id (trace , exp )) for exp in candidate_exps ],
@@ -682,7 +731,14 @@ def evaluate_one_trace(
682731 with sota_mle_score_paths [0 ].open ("rb" ) as f :
683732 sota_mle_score = extract_json (pickle .load (f ))
684733 hit = sota_mle_score .get ("any_medal" , False )
685- return competition , hit
734+ if hit :
735+ if sota_mle_score ["gold_medal" ]:
736+ sota_exp_stat = "gold"
737+ elif sota_mle_score ["silver_medal" ]:
738+ sota_exp_stat = "silver"
739+ elif sota_mle_score ["bronze_medal" ]:
740+ sota_exp_stat = "bronze"
741+ return competition , hit , sota_exp_stat
686742
687743
688744def select_on_existing_trace (
@@ -712,11 +768,21 @@ def select_on_existing_trace(
712768
713769 # Prepare list of tasks for multiprocessing
714770 tasks = []
771+ if debug and experiment and "yaml" in trace_root :
772+ job_info = yaml .safe_load (open (str (Path (trace_root ) / f"{ experiment } .yaml" ), "r" ))
773+ if not competition :
774+ competition = os .getenv ("DS_COMPETITION" )
775+ for job in job_info :
776+ if job ["submit_args" ]["env" ]["DS_COMPETITION" ] == competition :
777+ tar_file = Path ("/mnt/output" ) / job ["results_dir" ] / job ["submit_args" ]["env" ]["RD_RES_NAME" ]
778+ extract_tar (tar_file )
779+ debug = False
780+
715781 if debug :
716782 for trace_folder in trace_root_path .iterdir ():
717783 if not trace_folder .is_dir ():
718784 continue
719- if experiment is not None :
785+ if isinstance ( experiment , str ) and experiment :
720786 if trace_folder .name not in experiment :
721787 continue
722788 for trace_pkl_path in trace_folder .glob ("*.pkl" ):
@@ -779,13 +845,15 @@ def select_on_existing_trace(
779845 hit_list = multiprocessing_wrapper (tasks , n = 1 ) # n=1 for sequential debugging, increase for parallel runs
780846
781847 # Aggregate and report results
782- hit_count = sum (hit for _ , hit in hit_list if hit is not None )
848+ hit_count = sum (hit for _ , hit , _ in hit_list if hit is not None )
783849 total_valid_traces = len (hit_list )
784850
785851 print ("\n " + "=" * 50 )
786852 print (f"Evaluation Summary for Selector: '{ selector_name } '" )
787853 print (f"Total Traces Processed: { total_valid_traces } " )
788854 print (f"Total Hits: { hit_count } " )
855+ if not debug and hit_count :
856+ print (f"Medal info: { hit_list [0 ][2 ]} " )
789857 if total_valid_traces > 0 :
790858 hit_rate = (hit_count / total_valid_traces ) * 100
791859 print (f"Hit Rate: { hit_rate :.2f} %" )
@@ -796,11 +864,13 @@ def select_on_existing_trace(
796864 "total" : total_valid_traces ,
797865 "hit_rate" : hit_rate if total_valid_traces > 0 else 0 ,
798866 }
799- result_dict ["details" ] = [{comp : hit } for comp , hit in hit_list ]
867+ result_dict ["details" ] = [{comp : hit } for comp , hit , _ in hit_list ]
800868
801869 with open (f"result_{ selector_name } .json" , "w" ) as f :
802870 json .dump (result_dict , f , indent = 4 )
803871 logger .info (f"Results saved to result_{ selector_name } .json" )
872+ if "yaml" in trace_root and Path ("log/log" ).exists ():
873+ shutil .rmtree ("log/log" )
804874
805875
806876if __name__ == "__main__" :
0 commit comments