Commit c73f67a

jingyuanlm, xuangu-fang, and you-n-g authored
fix: fix mcts (#1270)
* init mcts class
* full ver of MCTS
* auto-lint
* make MCTS feedback in exp-gen()
* refactor: move reset logic from Trace to ExpGen and update usage accordingly
* fix: reinitialize trace on consecutive errors in DataScienceRDLoop
* feat: add reset method to BaseScheduler and call in MCTSScheduler reset
* style: reorder imports for consistency and PEP8 compliance
* lint
* fix observe_feedback
* fix bug
* remove uncommited_rec_status
* more simple
* refactor: move commit observation logic to process_uncommitted_nodes method
* docs: add TODO comment about rule-based virtual root node expansion
* add score reward
* fix bug
* fix small bug
* lint
* change reward
* small small change
* autolint

---------

Co-authored-by: xuangu-fang <[email protected]>
Co-authored-by: Young <[email protected]>
1 parent 6f86863 · commit c73f67a

File tree

1 file changed (+10 −15 lines)


rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py

Lines changed: 10 additions & 15 deletions
@@ -393,7 +393,7 @@ def select(self, trace: DSTrace) -> tuple[int, ...] | None:
 
         return (best_leaf,)
 
-    def observe_feedback(self, trace: DSTrace, new_idx: int, reward: float | None = None) -> None:
+    def observe_feedback(self, trace: DSTrace, new_idx: int) -> None:
         """
         Update statistics after an experiment is committed to the trace.
 
@@ -402,21 +402,16 @@ def observe_feedback(self, trace: DSTrace, new_idx: int, reward: float | None =
             new_idx: Index of the newly appended experiment in trace.hist.
             reward: Optional explicit reward. If None, derive from feedback.decision (1.0/0.0).
         """
-        if reward is None:
-            if 0 <= new_idx < len(trace.hist):
-                re, fb = trace.hist[new_idx]
-                if DS_RD_SETTING.enable_score_reward:
-                    bigger_is_better = get_metric_direction(trace.scen.competition)
-                    if getattr(fb, "decision", False):
-                        reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1)
-                    else:
-                        reward = -1 if bigger_is_better else 1
-                else:
-                    reward = 1.0 if getattr(fb, "decision", False) else 0.0
-            else:
-                # Out-of-range safety
-                reward = 0.0
 
+        re, fb = trace.hist[new_idx]
+        if DS_RD_SETTING.enable_score_reward:
+            bigger_is_better = get_metric_direction(trace.scen.competition)
+            if getattr(fb, "decision", False):
+                reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1)
+            else:
+                reward = -1 if bigger_is_better else 1
+        else:
+            reward = 1.0 if getattr(fb, "decision", False) else 0.0
         id_list = trace.get_parents(new_idx)
         for id in id_list:
             self.node_value_sum[id] = self.node_value_sum.get(id, 0.0) + float(reward)
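For readers who want to try the new reward rule outside the repo, here is a minimal, self-contained Python sketch of what observe_feedback now computes. The names compute_reward and backpropagate are invented here for illustration, and the DS_RD_SETTING flag and metric direction are passed as plain arguments; only math.tanh, the rounding, the sign handling, and the parent value-sum update mirror the diff above.

import math

def compute_reward(decision: bool, ensemble_score: float,
                   enable_score_reward: bool, bigger_is_better: bool) -> float:
    """Reward rule mirroring the diff (hypothetical standalone wrapper)."""
    if enable_score_reward:
        if decision:
            # Squash the rounded ensemble score into (-1, 1); negate it for
            # minimize-style metrics so lower scores yield higher reward.
            return math.tanh(round(ensemble_score, 3)) * (1 if bigger_is_better else -1)
        # Failed experiments get a fixed signed reward, exactly as in the diff.
        return -1.0 if bigger_is_better else 1.0
    # Score reward disabled: fall back to a binary reward from the decision.
    return 1.0 if decision else 0.0

def backpropagate(node_value_sum: dict[int, float], parents: list[int], reward: float) -> None:
    # As in the diff's tail: every ancestor of the new node accumulates
    # the reward into its running value sum.
    for node_id in parents:
        node_value_sum[node_id] = node_value_sum.get(node_id, 0.0) + float(reward)

# Example: a successful run scoring 0.87 on a maximize-style metric.
node_value_sum: dict[int, float] = {}
r = compute_reward(decision=True, ensemble_score=0.87,
                   enable_score_reward=True, bigger_is_better=True)
backpropagate(node_value_sum, parents=[0, 3, 7], reward=r)
print(node_value_sum)  # tanh(0.87) ≈ 0.701 credited to each ancestor

Design note: the rewrite also drops the old out-of-range guard (the reward = 0.0 fallback), so observe_feedback now assumes new_idx is a valid index into trace.hist.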
