@@ -611,6 +611,7 @@ module Curriculum = struct
611611 window_size;
612612 recent_rewards = ref [] ;
613613 }
614+
614615 let trim_to n lst =
615616 let rec aux idx acc = function
616617 | [] -> List. rev acc
@@ -632,7 +633,7 @@ module Curriculum = struct
632633 trim_to config.window_size (outcome :: ! (config.recent_rewards));
633634 let count = List. length ! (config.recent_rewards) in
634635 let minimum_samples = max 10 (config.window_size / 2 ) in
635- if count > = minimum_samples then (
636+ if count > = minimum_samples then
636637 let wins = List. filter (fun r -> r > 0.5 ) ! (config.recent_rewards) in
637638 let success_rate =
638639 float_of_int (List. length wins) /. float_of_int count
@@ -645,7 +646,7 @@ module Curriculum = struct
645646 incr config.current_idx;
646647 config.recent_rewards := [] ;
647648 true )
648- else false )
649+ else false
649650 else false
650651
651652 let get_current_stage config = List. nth config.stages ! (config.current_idx)
@@ -670,6 +671,7 @@ type state = {
670671
671672module Env_table = Hashtbl. Make (struct
672673 type t = Obj .t
674+
673675 let equal a b = a == b
674676 let hash = Hashtbl. hash
675677end )
@@ -689,7 +691,6 @@ let lookup_state env =
689691let max_grid_size = 10
690692let observation_channels = 8
691693let observation_flat_size = observation_channels * max_grid_size * max_grid_size
692-
693694let mask_channel_index = observation_channels - 1
694695
695696let cell_to_channel = function
@@ -700,6 +701,7 @@ let cell_to_channel = function
700701 | Core. Box_on_target -> 4
701702 | Core. Player -> 5
702703 | Core. Player_on_target -> 6
704+
703705let render_text state = Core. render state.game_state
704706
705707let tile_color = function
@@ -826,8 +828,7 @@ let action_mask state =
826828 can_move Core. Right state;
827829 |]
828830
829- let manhattan (x1 , y1 ) (x2 , y2 ) =
830- Stdlib. abs (x1 - x2) + Stdlib. abs (y1 - y2)
831+ let manhattan (x1 , y1 ) (x2 , y2 ) = Stdlib. abs (x1 - x2) + Stdlib. abs (y1 - y2)
831832
832833let boxes_and_targets state =
833834 let open Core in
@@ -836,15 +837,12 @@ let boxes_and_targets state =
836837 for y = 0 to state.height - 1 do
837838 for x = 0 to state.width - 1 do
838839 match state.grid.(y).(x) with
839- | Box ->
840- boxes := (x, y) :: ! boxes
840+ | Box -> boxes := (x, y) :: ! boxes
841841 | Box_on_target ->
842842 boxes := (x, y) :: ! boxes;
843843 targets := (x, y) :: ! targets
844- | Target ->
845- targets := (x, y) :: ! targets
846- | Player_on_target ->
847- targets := (x, y) :: ! targets
844+ | Target -> targets := (x, y) :: ! targets
845+ | Player_on_target -> targets := (x, y) :: ! targets
848846 | _ -> ()
849847 done
850848 done ;
@@ -856,7 +854,7 @@ let sorted_boxes state =
856854
857855let potential state =
858856 let boxes, targets = boxes_and_targets state in
859- match boxes, targets with
857+ match ( boxes, targets) with
860858 | [] , _ | _ , [] -> 0.0
861859 | _ ->
862860 let best_distance (x , y ) =
@@ -867,9 +865,7 @@ let potential state =
867865 let total =
868866 List. fold_left (fun acc box -> acc + best_distance box) 0 boxes
869867 in
870- let max_per_box =
871- max 1 ((state.width - 1 ) + (state.height - 1 ))
872- in
868+ let max_per_box = max 1 (state.width - 1 + (state.height - 1 )) in
873869 let max_total = max_per_box * List. length boxes in
874870 let diff = max_total - total in
875871 float_of_int (if diff > 0 then diff else 0 )
@@ -904,9 +900,7 @@ let stage_info curriculum_config =
904900 | None -> " 1/1"
905901
906902let registered_curriculum env =
907- try
908- (lookup_state env).curriculum_config
909- with Invalid_argument _ -> None
903+ try (lookup_state env).curriculum_config with Invalid_argument _ -> None
910904
911905let stage_to_string = function
912906 | Curriculum. Corridor len -> Printf. sprintf " corridor-%d" len
@@ -923,12 +917,8 @@ let stage_descriptor curriculum_config =
923917 Printf. sprintf " %s-%02d-of-%02d" (stage_to_string stage) (idx + 1 ) total
924918 | None -> " single-stage"
925919
926- let current_game_state env =
927- Core. copy_state (lookup_state env).game_state
928-
929- let has_registered_state env =
930- Env_table. mem state_registry (Obj. repr env)
931-
920+ let current_game_state env = Core. copy_state (lookup_state env).game_state
921+ let has_registered_state env = Env_table. mem state_registry (Obj. repr env)
932922let state_opt env = Env_table. find_opt state_registry (Obj. repr env)
933923
934924let current_stage env =
@@ -940,8 +930,7 @@ let current_stage env =
940930 Some (stage, idx, List. length config.Curriculum. stages)
941931 | None -> None
942932
943- let current_stage_label env =
944- stage_info (registered_curriculum env)
933+ let current_stage_label env = stage_info (registered_curriculum env)
945934
946935let current_stage_descriptor env =
947936 match state_opt env with
@@ -951,8 +940,8 @@ let current_stage_descriptor env =
951940let current_stage_descriptor_opt env =
952941 match state_opt env with
953942 | None -> None
954- | Some state ->
955- ( match state.curriculum_config with
943+ | Some state -> (
944+ match state.curriculum_config with
956945 | None -> None
957946 | Some _ -> Some (stage_descriptor state.curriculum_config))
958947
@@ -993,8 +982,7 @@ let reset _env ?options:_ () state =
993982 let obs = make_observation level in
994983 let info =
995984 Info. empty
996- |> Info. set " stage"
997- (Info. string (stage_descriptor state.curriculum_config))
985+ |> Info. set " stage" (Info. string (stage_descriptor state.curriculum_config))
998986 |> Info. set " action_mask" (Info. bool_array (action_mask level))
999987 in
1000988 (obs, info)
@@ -1005,17 +993,15 @@ let step _env action state =
1005993 let boxes_before = sorted_boxes state.game_state in
1006994 let direction = action_to_direction action in
1007995 let new_state = Core. apply_action state.game_state direction in
1008- let moved = not (Stdlib. (== ) new_state state.game_state) in
996+ let moved = not (Stdlib. ( == ) new_state state.game_state) in
1009997 state.game_state < - new_state;
1010998 let boxes_after = sorted_boxes state.game_state in
1011999 let pushed = boxes_before <> boxes_after in
10121000 let won = Core. check_win state.game_state in
10131001 let no_moves = not (has_any_move state.game_state) in
10141002 let truncated = state.steps > = state.max_steps || no_moves in
10151003
1016- let base_reward =
1017- if won then 100.0 else if not moved then - 0.2 else - 0.01
1018- in
1004+ let base_reward = if won then 100.0 else if not moved then - 0.2 else - 0.01 in
10191005 let phi_s' = potential state.game_state in
10201006 let shaping =
10211007 if pushed then shaping_beta *. ((shaping_gamma *. phi_s') -. phi_s) else 0.0
0 commit comments