Skip to content

Commit d2e6c0c

Browse files
authored
Moved ITE and loops treatment to cps_conversion (#424)
1 parent 1883839 commit d2e6c0c

File tree

5 files changed

+117
-111
lines changed

5 files changed

+117
-111
lines changed

middle_end/flambda/flambda_middle_end.ml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,14 @@ let middle_end0 ppf ~prefixname ~backend ~filename ~module_ident
113113
~module_block_size_in_words ~module_initializer =
114114
Misc.Color.setup !Clflags.color;
115115
Profile.record_call "flambda.0" (fun () ->
116-
let prepared_lambda, recursive_static_catches =
116+
let prepared_lambda =
117117
Profile.record_call "prepare_lambda" (fun () ->
118118
Prepare_lambda.run module_initializer)
119119
in
120120
print_prepared_lambda ppf prepared_lambda;
121121
let ilambda =
122122
Profile.record_call "cps_conversion" (fun () ->
123-
Cps_conversion.lambda_to_ilambda prepared_lambda
124-
~recursive_static_catches)
123+
Cps_conversion.lambda_to_ilambda prepared_lambda)
125124
in
126125
print_ilambda ppf ilambda;
127126
let ilambda =

middle_end/flambda/from_lambda/cps_conversion.ml

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ let try_stack_at_handler = ref Continuation.Map.empty
5858
let recursive_static_catches = ref Numbers.Int.Set.empty
5959
let mutable_variables = ref Ident.Set.empty
6060

61+
let mark_as_recursive_static_catch cont =
62+
if Numbers.Int.Set.mem cont !recursive_static_catches then begin
63+
Misc.fatal_errorf "Static catch with continuation %d already marked as \
64+
recursive -- is it being redefined?"
65+
cont
66+
end;
67+
recursive_static_catches := Numbers.Int.Set.add cont !recursive_static_catches
68+
6169
let _print_stack ppf stack =
6270
Format.fprintf ppf "%a"
6371
(Format.pp_print_list ~pp_sep:(fun ppf () -> Format.fprintf ppf "; ")
@@ -116,7 +124,7 @@ let compile_staticfail ~(continuation : Continuation.t) ~args =
116124
in
117125
mk_poptraps (I.Apply_cont (continuation, None, args))
118126

119-
let switch_for_if_then_else ~cond ~ifso ~ifnot k =
127+
let switch_for_if_then_else ~cond ~ifso ~ifnot =
120128
(* CR mshinwell: We need to make sure that [cond] is {0, 1}-valued.
121129
The frontend should have been fixed on this branch for this. *)
122130
let switch : Lambda.lambda_switch =
@@ -128,7 +136,7 @@ let switch_for_if_then_else ~cond ~ifso ~ifnot k =
128136
sw_tags_to_sizes = Tag.Scannable.Map.empty;
129137
}
130138
in
131-
k (L.Lswitch (cond, switch, Loc_unknown))
139+
L.Lswitch (cond, switch, Loc_unknown)
132140

133141
let transform_primitive (prim : L.primitive) args loc =
134142
match prim, args with
@@ -141,8 +149,7 @@ let transform_primitive (prim : L.primitive) args loc =
141149
switch_for_if_then_else
142150
~cond:(L.Lvar cond)
143151
~ifso:(L.Lvar const_true)
144-
~ifnot:arg2
145-
(fun lam -> lam)))))
152+
~ifnot:arg2))))
146153
| Psequand, [arg1; arg2] ->
147154
let const_false = Ident.create_local "const_false" in
148155
let cond = Ident.create_local "cond_sequand" in
@@ -152,8 +159,7 @@ let transform_primitive (prim : L.primitive) args loc =
152159
switch_for_if_then_else
153160
~cond:(L.Lvar cond)
154161
~ifso:arg2
155-
~ifnot:(L.Lvar const_false)
156-
(fun lam -> lam)))))
162+
~ifnot:(L.Lvar const_false)))))
157163
| (Psequand | Psequor), _ ->
158164
Misc.fatal_error "Psequand / Psequor must have exactly two arguments"
159165
(* Removed. Should be safe, but will no longer catch misuses.
@@ -165,8 +171,7 @@ let transform_primitive (prim : L.primitive) args loc =
165171
(switch_for_if_then_else
166172
~cond:(L.Lprim (Pflambda_isint, [arg], loc))
167173
~ifso:(L.Lconst (Const_base (Const_int 1)))
168-
~ifnot:(L.Lconst (Const_base (Const_int 0)))
169-
(fun lam -> lam))
174+
~ifnot:(L.Lconst (Const_base (Const_int 0))))
170175
| (Pidentity | Pbytes_to_string | Pbytes_of_string), [arg] -> Transformed arg
171176
| Pignore, [arg] ->
172177
let ident = Ident.create_local "ignore" in
@@ -247,6 +252,66 @@ let transform_primitive (prim : L.primitive) args loc =
247252
end
248253
| _, _ -> Primitive (prim, args, loc)
249254

255+
let rec_catch_for_while_loop cond body =
256+
let cont = L.next_raise_count () in
257+
mark_as_recursive_static_catch cont;
258+
let cond_result = Ident.create_local "while_cond_result" in
259+
let lam : L.lambda =
260+
Lstaticcatch (
261+
Lstaticraise (cont, []),
262+
(cont, []),
263+
Llet (Strict, Pgenval, cond_result, cond,
264+
Lifthenelse (Lvar cond_result,
265+
Lsequence (
266+
body,
267+
Lstaticraise (cont, [])),
268+
Lconst (Const_base (Const_int 0)))))
269+
in lam
270+
271+
let rec_catch_for_for_loop
272+
ident start stop (dir : Asttypes.direction_flag) body =
273+
let cont = L.next_raise_count () in
274+
mark_as_recursive_static_catch cont;
275+
let start_ident = Ident.create_local "for_start" in
276+
let stop_ident = Ident.create_local "for_stop" in
277+
let first_test : L.lambda =
278+
match dir with
279+
| Upto ->
280+
Lprim (Pintcomp Cle,
281+
[L.Lvar start_ident; L.Lvar stop_ident],
282+
Loc_unknown)
283+
| Downto ->
284+
Lprim (Pintcomp Cge,
285+
[L.Lvar start_ident; L.Lvar stop_ident],
286+
Loc_unknown)
287+
in
288+
let subsequent_test : L.lambda =
289+
Lprim (Pintcomp Cne, [L.Lvar ident; L.Lvar stop_ident], Loc_unknown)
290+
in
291+
let one : L.lambda = Lconst (Const_base (Const_int 1)) in
292+
let next_value_of_counter =
293+
match dir with
294+
| Upto -> L.Lprim (Paddint, [L.Lvar ident; one], Loc_unknown)
295+
| Downto -> L.Lprim (Psubint, [L.Lvar ident; one], Loc_unknown)
296+
in
297+
let lam : L.lambda =
298+
(* Care needs to be taken here not to cause overflow if, for an
299+
incrementing for-loop, the upper bound is [max_int]; likewise, for
300+
a decrementing for-loop, if the lower bound is [min_int]. *)
301+
Llet (Strict, Pgenval, start_ident, start,
302+
Llet (Strict, Pgenval, stop_ident, stop,
303+
Lifthenelse (first_test,
304+
Lstaticcatch (
305+
Lstaticraise (cont, [L.Lvar start_ident]),
306+
(cont, [ident, Pgenval]),
307+
Lsequence (
308+
body,
309+
Lifthenelse (subsequent_test,
310+
Lstaticraise (cont, [next_value_of_counter]),
311+
L.lambda_unit))),
312+
L.lambda_unit)))
313+
in lam
314+
250315
let rec cps_non_tail (lam : L.lambda) (k : Ident.t -> Ilambda.t)
251316
(k_exn : Continuation.t) : Ilambda.t =
252317
match lam with
@@ -536,13 +601,25 @@ let rec cps_non_tail (lam : L.lambda) (k : Ident.t -> Ilambda.t)
536601
};
537602
handler = k result_var;
538603
}
604+
| Lifthenelse (cond, ifso, ifnot) ->
605+
let lam = switch_for_if_then_else ~cond ~ifso ~ifnot in
606+
cps_non_tail lam k k_exn
607+
| Lsequence (lam1, lam2) ->
608+
let ident = Ident.create_local "sequence" in
609+
cps_non_tail (L.Llet (Strict, Pgenval, ident, lam1, lam2)) k k_exn
610+
| Lwhile (cond, body) ->
611+
let loop = rec_catch_for_while_loop cond body in
612+
cps_non_tail loop k k_exn
613+
| Lfor (ident, start, stop, dir, body) ->
614+
let loop = rec_catch_for_for_loop ident start stop dir body in
615+
cps_non_tail loop k k_exn
539616
| Lassign (being_assigned, new_value) ->
540617
cps_non_tail_simple new_value (fun new_value ->
541618
name_then_cps_non_tail "assign"
542619
(I.Assign { being_assigned; new_value; })
543620
k k_exn)
544621
k_exn
545-
| Lsequence _ | Lifthenelse _ | Lwhile _ | Lfor _ | Lifused _ | Levent _ ->
622+
| Lifused _ | Levent _ ->
546623
Misc.fatal_errorf "Term should have been eliminated by [Prepare_lambda]: %a"
547624
Printlambda.lambda lam
548625

@@ -818,7 +895,19 @@ and cps_tail (lam : L.lambda) (k : Continuation.t) (k_exn : Continuation.t)
818895
};
819896
handler;
820897
}
821-
| Lsequence _ | Lifthenelse _ | Lwhile _ | Lfor _ | Lifused _ | Levent _ ->
898+
| Lifthenelse (cond, ifso, ifnot) ->
899+
let lam = switch_for_if_then_else ~cond ~ifso ~ifnot in
900+
cps_tail lam k k_exn
901+
| Lsequence (lam1, lam2) ->
902+
let ident = Ident.create_local "sequence" in
903+
cps_tail (L.Llet (Strict, Pgenval, ident, lam1, lam2)) k k_exn
904+
| Lwhile (cond, body) ->
905+
let loop = rec_catch_for_while_loop cond body in
906+
cps_tail loop k k_exn
907+
| Lfor (ident, start, stop, dir, body) ->
908+
let loop = rec_catch_for_for_loop ident start stop dir body in
909+
cps_tail loop k k_exn
910+
| Lifused _ | Levent _ ->
822911
Misc.fatal_errorf "Term should have been eliminated by [Prepare_lambda]: %a"
823912
Printlambda.lambda lam
824913

@@ -943,12 +1032,11 @@ and cps_switch (switch : proto_switch) ~scrutinee (k : Continuation.t)
9431032
wrappers)
9441033
k_exn
9451034

946-
let lambda_to_ilambda lam ~recursive_static_catches:recursive_static_catches'
947-
: Ilambda.program =
1035+
let lambda_to_ilambda lam : Ilambda.program =
9481036
static_exn_env := Numbers.Int.Map.empty;
9491037
try_stack := [];
9501038
try_stack_at_handler := Continuation.Map.empty;
951-
recursive_static_catches := recursive_static_catches';
1039+
recursive_static_catches := Numbers.Int.Set.empty;
9521040
mutable_variables := Ident.Set.empty;
9531041
let the_end = Continuation.create ~sort:Define_root_symbol () in
9541042
let the_end_exn = Continuation.create ~sort:Exn () in

middle_end/flambda/from_lambda/cps_conversion.mli

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,4 @@
1818

1919
[@@@ocaml.warning "+a-4-30-40-41-42"]
2020

21-
val lambda_to_ilambda
22-
: Lambda.lambda
23-
-> recursive_static_catches:Numbers.Int.Set.t
24-
-> Ilambda.program
21+
val lambda_to_ilambda : Lambda.lambda -> Ilambda.program

middle_end/flambda/from_lambda/prepare_lambda.ml

Lines changed: 12 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -62,31 +62,6 @@ end = struct
6262
}
6363
end
6464

65-
(* CR-soon mshinwell: Remove mutable state *)
66-
let recursive_static_catches = ref Numbers.Int.Set.empty
67-
68-
let mark_as_recursive_static_catch cont =
69-
if Numbers.Int.Set.mem cont !recursive_static_catches then begin
70-
Misc.fatal_errorf "Static catch with continuation %d already marked as \
71-
recursive -- is it being redefined?"
72-
cont
73-
end;
74-
recursive_static_catches := Numbers.Int.Set.add cont !recursive_static_catches
75-
76-
let switch_for_if_then_else ~cond ~ifso ~ifnot k =
77-
(* CR mshinwell: We need to make sure that [cond] is {0, 1}-valued.
78-
The frontend should have been fixed on this branch for this. *)
79-
let switch : Lambda.lambda_switch =
80-
{ sw_numconsts = 2;
81-
sw_consts = [0, ifnot; 1, ifso];
82-
sw_numblocks = 0;
83-
sw_blocks = [];
84-
sw_failaction = None;
85-
sw_tags_to_sizes = Tag.Scannable.Map.empty;
86-
}
87-
in
88-
k (L.Lswitch (cond, switch, Loc_unknown))
89-
9065
(*
9166
let simplify_primitive (prim : L.primitive) args loc =
9267
match prim, args with
@@ -349,65 +324,20 @@ let rec prepare env (lam : L.lambda) (k : L.lambda -> L.lambda) =
349324
prepare env cond (fun cond ->
350325
prepare env ifso (fun ifso ->
351326
prepare env ifnot (fun ifnot ->
352-
switch_for_if_then_else ~cond ~ifso ~ifnot k)))
327+
k (L.Lifthenelse(cond, ifso, ifnot)))))
353328
| Lsequence (lam1, lam2) ->
354-
let ident = Ident.create_local "sequence" in
355-
prepare env (L.Llet (Strict, Pgenval, ident, lam1, lam2)) k
329+
prepare env lam1 (fun lam1 ->
330+
prepare env lam2 (fun lam2 ->
331+
k (L.Lsequence(lam1, lam2))))
356332
| Lwhile (cond, body) ->
357-
let cont = L.next_raise_count () in
358-
mark_as_recursive_static_catch cont;
359-
let cond_result = Ident.create_local "cond_result" in
360-
let lam : L.lambda =
361-
Lstaticcatch (
362-
Lstaticraise (cont, []),
363-
(cont, []),
364-
Llet (Strict, Pgenval, cond_result, cond,
365-
Lifthenelse (Lvar cond_result,
366-
Lsequence (
367-
body,
368-
Lstaticraise (cont, [])),
369-
Lconst (Const_base (Const_int 0)))))
370-
in
371-
prepare env lam k
333+
prepare env cond (fun cond ->
334+
prepare env body (fun body ->
335+
k (Lwhile (cond, body))))
372336
| Lfor (ident, start, stop, dir, body) ->
373-
let cont = L.next_raise_count () in
374-
mark_as_recursive_static_catch cont;
375-
let start_ident = Ident.create_local "start" in
376-
let stop_ident = Ident.create_local "stop" in
377-
let first_test : L.lambda =
378-
match dir with
379-
| Upto ->
380-
Lprim (Pintcomp Cle, [L.Lvar start_ident; L.Lvar stop_ident], Loc_unknown)
381-
| Downto ->
382-
Lprim (Pintcomp Cge, [L.Lvar start_ident; L.Lvar stop_ident], Loc_unknown)
383-
in
384-
let subsequent_test : L.lambda =
385-
Lprim (Pintcomp Cne, [L.Lvar ident; L.Lvar stop_ident], Loc_unknown)
386-
in
387-
let one : L.lambda = Lconst (Const_base (Const_int 1)) in
388-
let next_value_of_counter =
389-
match dir with
390-
| Upto -> L.Lprim (Paddint, [L.Lvar ident; one], Loc_unknown)
391-
| Downto -> L.Lprim (Psubint, [L.Lvar ident; one], Loc_unknown)
392-
in
393-
let lam : L.lambda =
394-
(* Care needs to be taken here not to cause overflow if, for an
395-
incrementing for-loop, the upper bound is [max_int]; likewise, for
396-
a decrementing for-loop, if the lower bound is [min_int]. *)
397-
Llet (Strict, Pgenval, start_ident, start,
398-
Llet (Strict, Pgenval, stop_ident, stop,
399-
Lifthenelse (first_test,
400-
Lstaticcatch (
401-
Lstaticraise (cont, [start]),
402-
(cont, [ident, Pgenval]),
403-
Lsequence (
404-
body,
405-
Lifthenelse (subsequent_test,
406-
Lstaticraise (cont, [next_value_of_counter]),
407-
L.lambda_unit))),
408-
L.lambda_unit)))
409-
in
410-
prepare env lam k
337+
prepare env start (fun start ->
338+
prepare env stop (fun stop ->
339+
prepare env body (fun body ->
340+
k (L.Lfor (ident, start, stop, dir, body)))))
411341
| Lassign (ident, lam) ->
412342
if not (Env.is_mutable env ident) then begin
413343
Misc.fatal_errorf "Lassign on non-mutable variable %a"
@@ -456,11 +386,9 @@ and prepare_option env lam_opt k =
456386
| Some lam -> prepare env lam (fun lam -> k (Some lam))
457387

458388
let run lam =
459-
recursive_static_catches := Numbers.Int.Set.empty;
460389
let current_unit_id =
461390
Compilation_unit.get_persistent_ident
462391
(Compilation_unit.get_current_exn ())
463392
in
464393
let env = Env.create ~current_unit_id in
465-
let lam = prepare env lam (fun lam -> lam) in
466-
lam, !recursive_static_catches
394+
prepare env lam (fun lam -> lam)

middle_end/flambda/from_lambda/prepare_lambda.mli

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,4 @@
1616

1717
(** Preparation of [Lambda] code before CPS and closure conversion. *)
1818

19-
(** The set of integers returned by [run] identifies all those [Lstaticcatch]
20-
handlers which are to be treated as recursive. (This is rather more
21-
straightforward than changing the type in [Lambda] to accommodate
22-
this). *)
23-
val run
24-
: Lambda.lambda
25-
-> Lambda.lambda * Numbers.Int.Set.t
19+
val run : Lambda.lambda -> Lambda.lambda

0 commit comments

Comments
 (0)