Skip to content

Commit 4ec76f3

Browse files
committed
chore: Format code
1 parent b2e68e1 commit 4ec76f3

File tree

7 files changed

+106
-168
lines changed

7 files changed

+106
-168
lines changed

dev/makemore/makemore.ml

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ let train_transformer ~vocab_size ~block_size ~n_layer ~n_head ~n_embd ~lr
290290
Ptree.dict
291291
[
292292
( "attn",
293-
Kaun.Attention.Multi_head.init attn_config ~rngs:keys.(0)
294-
~dtype );
293+
Kaun.Attention.Multi_head.init attn_config ~rngs:keys.(0) ~dtype
294+
);
295295
("ln1", ln1.init ~rngs:keys.(1) ~dtype);
296296
("ln2", ln2.init ~rngs:keys.(2) ~dtype);
297297
("ff", ff.init ~rngs:keys.(3) ~dtype);
@@ -301,8 +301,7 @@ let train_transformer ~vocab_size ~block_size ~n_layer ~n_head ~n_embd ~lr
301301
let fields =
302302
match params with
303303
| Ptree.Dict fields -> fields
304-
| _ ->
305-
failwith "transformer_decoder_block: params must be a dict"
304+
| _ -> failwith "transformer_decoder_block: params must be a dict"
306305
in
307306
let find name =
308307
match List.assoc_opt name fields with
@@ -321,12 +320,8 @@ let train_transformer ~vocab_size ~block_size ~n_layer ~n_head ~n_embd ~lr
321320
let positions =
322321
Rune.arange Rune.int32 0 seq_len 1 |> Rune.reshape [| 1; seq_len |]
323322
in
324-
let query_idx =
325-
Rune.reshape [| 1; seq_len; 1 |] positions
326-
in
327-
let key_idx =
328-
Rune.reshape [| 1; 1; seq_len |] positions
329-
in
323+
let query_idx = Rune.reshape [| 1; seq_len; 1 |] positions in
324+
let key_idx = Rune.reshape [| 1; 1; seq_len |] positions in
330325
let base_mask = Rune.less_equal key_idx query_idx in
331326
let attention_mask =
332327
if batch = 1 then base_mask

fehu/examples/05-sokoban/reinforce_sokoban.ml

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ let add_channel_dim ~n_channels () =
4040
(fun _ ~training:_ ?rngs:_ x ->
4141
let shape = Rune.shape x in
4242
let batch =
43-
match Array.length shape with
44-
| 1 -> 1
45-
| 0 -> 1
46-
| _ -> shape.(0)
43+
match Array.length shape with 1 -> 1 | 0 -> 1 | _ -> shape.(0)
4744
in
4845
Rune.reshape [| batch; n_channels; grid_size; grid_size |] x);
4946
}
@@ -94,9 +91,11 @@ let apply_action_mask logits = function
9491
let len = Array.length mask in
9592
let mask_offsets =
9693
Array.init n_actions (fun idx ->
97-
if idx < len && mask.(idx) then 0.0 else -.1e9)
94+
if idx < len && mask.(idx) then 0.0 else -1e9)
95+
in
96+
let mask_tensor =
97+
Rune.create Rune.float32 [| 1; n_actions |] mask_offsets
9898
in
99-
let mask_tensor = Rune.create Rune.float32 [| 1; n_actions |] mask_offsets in
10099
Rune.add logits mask_tensor
101100
| None -> logits
102101

@@ -133,7 +132,8 @@ let record_random_rollout ~path ~max_steps =
133132
let record_trained_rollout ~level ~path ~max_steps ~policy_net ~params =
134133
let env =
135134
Sokoban_env.sokoban ~render_mode:`Rgb_array ~max_steps
136-
~initial_state:(Sokoban_env.Core.copy_state level) ()
135+
~initial_state:(Sokoban_env.Core.copy_state level)
136+
()
137137
in
138138
let policy =
139139
Policy.deterministic (fun obs ->
@@ -265,18 +265,19 @@ let train ?record_dir env config =
265265
in
266266

267267
Printf.printf
268-
"Episode %d (Stage %s): Avg Reward = %.2f, Win Rate = %.1f%% (%.1f%%), Length = \
269-
%d\n\
268+
"Episode %d (Stage %s): Avg Reward = %.2f, Win Rate = %.1f%% \
269+
(%.1f%%), Length = %d\n\
270270
%!"
271271
metrics.total_episodes stage_desc avg_reward recent_win_rate
272-
(float_of_int !total_wins /. float_of_int metrics.total_episodes
272+
(float_of_int !total_wins
273+
/. float_of_int metrics.total_episodes
273274
*. 100.0)
274275
metrics.episode_length;
275276
Printf.printf
276277
" Entropy = %.3f, Log Prob = %.3f, Adv Mean = %.3f, Adv \
277278
Std = %.3f"
278-
metrics.avg_entropy metrics.avg_log_prob
279-
metrics.adv_mean metrics.adv_std;
279+
metrics.avg_entropy metrics.avg_log_prob metrics.adv_mean
280+
metrics.adv_std;
280281
(match metrics.value_loss with
281282
| Some v -> Printf.printf ", Value Loss = %.3f" v
282283
| None -> ());
@@ -297,13 +298,11 @@ let train ?record_dir env config =
297298
(Printf.sprintf "sokoban_train_ep%04d_%s.mp4"
298299
metrics.total_episodes stage_desc)
299300
in
300-
Printf.printf
301-
"Recording rollout at episode %d (Stage %s) to %s\n%!"
301+
Printf.printf "Recording rollout at episode %d (Stage %s) to %s\n%!"
302302
metrics.total_episodes stage_desc path;
303303
record_guard "recording training rollout" (fun () ->
304304
record_trained_rollout ~level ~path ~max_steps:config.max_steps
305-
~policy_net
306-
~params:!params_ref)))
305+
~policy_net ~params:!params_ref)))
307306
record_dir);
308307

309308
if
@@ -398,18 +397,18 @@ let () =
398397
Filename.concat dir
399398
(Printf.sprintf "sokoban_trained_%s.mp4" final_stage_desc)
400399
in
401-
Printf.printf "Recording trained rollout (%s) to %s\n%!"
402-
final_stage_desc trained_path;
400+
Printf.printf "Recording trained rollout (%s) to %s\n%!" final_stage_desc
401+
trained_path;
403402
record_guard "recording trained rollout" (fun () ->
404-
record_trained_rollout ~level:final_level ~path:trained_path
403+
record_trained_rollout ~level:final_level ~path:trained_path
405404
~max_steps:config.max_steps ~policy_net ~params))
406405
record_dir;
407406

408407
(* Compare with random policy *)
409-
Printf.printf "\nEvaluating random policy on stage %s...\n%!"
410-
final_stage_desc;
408+
Printf.printf "\nEvaluating random policy on stage %s...\n%!" final_stage_desc;
411409
let random_env =
412-
Sokoban_env.sokoban ~max_steps:config.max_steps ~initial_state:final_level ()
410+
Sokoban_env.sokoban ~max_steps:config.max_steps ~initial_state:final_level
411+
()
413412
in
414413
let random_policy = Policy.random random_env in
415414
let random_stats =

fehu/examples/05-sokoban/sokoban_env.ml

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ module Curriculum = struct
611611
window_size;
612612
recent_rewards = ref [];
613613
}
614+
614615
let trim_to n lst =
615616
let rec aux idx acc = function
616617
| [] -> List.rev acc
@@ -632,7 +633,7 @@ module Curriculum = struct
632633
trim_to config.window_size (outcome :: !(config.recent_rewards));
633634
let count = List.length !(config.recent_rewards) in
634635
let minimum_samples = max 10 (config.window_size / 2) in
635-
if count >= minimum_samples then (
636+
if count >= minimum_samples then
636637
let wins = List.filter (fun r -> r > 0.5) !(config.recent_rewards) in
637638
let success_rate =
638639
float_of_int (List.length wins) /. float_of_int count
@@ -645,7 +646,7 @@ module Curriculum = struct
645646
incr config.current_idx;
646647
config.recent_rewards := [];
647648
true)
648-
else false)
649+
else false
649650
else false
650651

651652
let get_current_stage config = List.nth config.stages !(config.current_idx)
@@ -670,6 +671,7 @@ type state = {
670671

671672
module Env_table = Hashtbl.Make (struct
672673
type t = Obj.t
674+
673675
let equal a b = a == b
674676
let hash = Hashtbl.hash
675677
end)
@@ -689,7 +691,6 @@ let lookup_state env =
689691
let max_grid_size = 10
690692
let observation_channels = 8
691693
let observation_flat_size = observation_channels * max_grid_size * max_grid_size
692-
693694
let mask_channel_index = observation_channels - 1
694695

695696
let cell_to_channel = function
@@ -700,6 +701,7 @@ let cell_to_channel = function
700701
| Core.Box_on_target -> 4
701702
| Core.Player -> 5
702703
| Core.Player_on_target -> 6
704+
703705
let render_text state = Core.render state.game_state
704706

705707
let tile_color = function
@@ -826,8 +828,7 @@ let action_mask state =
826828
can_move Core.Right state;
827829
|]
828830

829-
let manhattan (x1, y1) (x2, y2) =
830-
Stdlib.abs (x1 - x2) + Stdlib.abs (y1 - y2)
831+
let manhattan (x1, y1) (x2, y2) = Stdlib.abs (x1 - x2) + Stdlib.abs (y1 - y2)
831832

832833
let boxes_and_targets state =
833834
let open Core in
@@ -836,15 +837,12 @@ let boxes_and_targets state =
836837
for y = 0 to state.height - 1 do
837838
for x = 0 to state.width - 1 do
838839
match state.grid.(y).(x) with
839-
| Box ->
840-
boxes := (x, y) :: !boxes
840+
| Box -> boxes := (x, y) :: !boxes
841841
| Box_on_target ->
842842
boxes := (x, y) :: !boxes;
843843
targets := (x, y) :: !targets
844-
| Target ->
845-
targets := (x, y) :: !targets
846-
| Player_on_target ->
847-
targets := (x, y) :: !targets
844+
| Target -> targets := (x, y) :: !targets
845+
| Player_on_target -> targets := (x, y) :: !targets
848846
| _ -> ()
849847
done
850848
done;
@@ -856,7 +854,7 @@ let sorted_boxes state =
856854

857855
let potential state =
858856
let boxes, targets = boxes_and_targets state in
859-
match boxes, targets with
857+
match (boxes, targets) with
860858
| [], _ | _, [] -> 0.0
861859
| _ ->
862860
let best_distance (x, y) =
@@ -867,9 +865,7 @@ let potential state =
867865
let total =
868866
List.fold_left (fun acc box -> acc + best_distance box) 0 boxes
869867
in
870-
let max_per_box =
871-
max 1 ((state.width - 1) + (state.height - 1))
872-
in
868+
let max_per_box = max 1 (state.width - 1 + (state.height - 1)) in
873869
let max_total = max_per_box * List.length boxes in
874870
let diff = max_total - total in
875871
float_of_int (if diff > 0 then diff else 0)
@@ -904,9 +900,7 @@ let stage_info curriculum_config =
904900
| None -> "1/1"
905901

906902
let registered_curriculum env =
907-
try
908-
(lookup_state env).curriculum_config
909-
with Invalid_argument _ -> None
903+
try (lookup_state env).curriculum_config with Invalid_argument _ -> None
910904

911905
let stage_to_string = function
912906
| Curriculum.Corridor len -> Printf.sprintf "corridor-%d" len
@@ -923,12 +917,8 @@ let stage_descriptor curriculum_config =
923917
Printf.sprintf "%s-%02d-of-%02d" (stage_to_string stage) (idx + 1) total
924918
| None -> "single-stage"
925919

926-
let current_game_state env =
927-
Core.copy_state (lookup_state env).game_state
928-
929-
let has_registered_state env =
930-
Env_table.mem state_registry (Obj.repr env)
931-
920+
let current_game_state env = Core.copy_state (lookup_state env).game_state
921+
let has_registered_state env = Env_table.mem state_registry (Obj.repr env)
932922
let state_opt env = Env_table.find_opt state_registry (Obj.repr env)
933923

934924
let current_stage env =
@@ -940,8 +930,7 @@ let current_stage env =
940930
Some (stage, idx, List.length config.Curriculum.stages)
941931
| None -> None
942932

943-
let current_stage_label env =
944-
stage_info (registered_curriculum env)
933+
let current_stage_label env = stage_info (registered_curriculum env)
945934

946935
let current_stage_descriptor env =
947936
match state_opt env with
@@ -951,8 +940,8 @@ let current_stage_descriptor env =
951940
let current_stage_descriptor_opt env =
952941
match state_opt env with
953942
| None -> None
954-
| Some state ->
955-
(match state.curriculum_config with
943+
| Some state -> (
944+
match state.curriculum_config with
956945
| None -> None
957946
| Some _ -> Some (stage_descriptor state.curriculum_config))
958947

@@ -993,8 +982,7 @@ let reset _env ?options:_ () state =
993982
let obs = make_observation level in
994983
let info =
995984
Info.empty
996-
|> Info.set "stage"
997-
(Info.string (stage_descriptor state.curriculum_config))
985+
|> Info.set "stage" (Info.string (stage_descriptor state.curriculum_config))
998986
|> Info.set "action_mask" (Info.bool_array (action_mask level))
999987
in
1000988
(obs, info)
@@ -1005,17 +993,15 @@ let step _env action state =
1005993
let boxes_before = sorted_boxes state.game_state in
1006994
let direction = action_to_direction action in
1007995
let new_state = Core.apply_action state.game_state direction in
1008-
let moved = not (Stdlib.(==) new_state state.game_state) in
996+
let moved = not (Stdlib.( == ) new_state state.game_state) in
1009997
state.game_state <- new_state;
1010998
let boxes_after = sorted_boxes state.game_state in
1011999
let pushed = boxes_before <> boxes_after in
10121000
let won = Core.check_win state.game_state in
10131001
let no_moves = not (has_any_move state.game_state) in
10141002
let truncated = state.steps >= state.max_steps || no_moves in
10151003

1016-
let base_reward =
1017-
if won then 100.0 else if not moved then -0.2 else -0.01
1018-
in
1004+
let base_reward = if won then 100.0 else if not moved then -0.2 else -0.01 in
10191005
let phi_s' = potential state.game_state in
10201006
let shaping =
10211007
if pushed then shaping_beta *. ((shaping_gamma *. phi_s') -. phi_s) else 0.0

0 commit comments

Comments
 (0)