Skip to content

Commit 38f9746

Browse files
committed
Add prevent_primitive_inlining flag for forward-over-reverse mode
1 parent 94c026d commit 38f9746

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

src/interpreter/forward_mode.jl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ end
2929
function build_frule(args...; debug_mode=false, silence_debug_messages=true)
3030
sig = _typeof(TestUtils.__get_primals(args))
3131
interp = get_interpreter(ForwardMode)
32-
return build_frule(interp, sig; debug_mode, silence_debug_messages)
32+
return build_frule(
33+
interp, sig; debug_mode, silence_debug_messages, prevent_primitive_inlining=false
34+
)
3335
end
3436

3537
struct DualRuleInfo
@@ -45,20 +47,25 @@ end
4547
debug_mode=false,
4648
silence_debug_messages=true,
4749
skip_world_age_check=false,
50+
prevent_primitive_inlining=false,
4851
) where {C}
4952
5053
Returns a function which performs forward-mode AD for `sig_or_mi`. Will derive a rule if
5154
`sig_or_mi` is not a primitive.
5255
5356
Set `skip_world_age_check=true` when the interpreter's world age is intentionally older
5457
than the current world (e.g., when building rules for MistyClosure which uses its own world).
58+
59+
Set `prevent_primitive_inlining=true` for forward-over-reverse mode. This prevents primitives
60+
from being inlined during optimization, which would expose foreigncalls that lack frule!!.
5561
"""
5662
function build_frule(
5763
interp::MooncakeInterpreter{C},
5864
sig_or_mi;
5965
debug_mode=false,
6066
silence_debug_messages=true,
6167
skip_world_age_check=false,
68+
prevent_primitive_inlining=false,
6269
) where {C}
6370
@nospecialize sig_or_mi
6471

@@ -89,12 +96,16 @@ function build_frule(
8996
try
9097
# If we've already derived the OpaqueClosures and info, do not re-derive, just
9198
# create a copy and pass in new shared data.
92-
oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode, :forward))
99+
oc_cache_key = ClosureCacheKey(
100+
interp.world, (sig_or_mi, debug_mode, :forward, prevent_primitive_inlining)
101+
)
93102
if haskey(interp.oc_cache, oc_cache_key)
94103
return interp.oc_cache[oc_cache_key]
95104
else
96105
# Derive forward-pass IR, and shove in a `MistyClosure`.
97-
dual_ir, captures, info = generate_dual_ir(interp, sig_or_mi; debug_mode)
106+
dual_ir, captures, info = generate_dual_ir(
107+
interp, sig_or_mi; debug_mode, prevent_primitive_inlining
108+
)
98109
dual_oc = misty_closure(
99110
info.dual_ret_type, dual_ir, captures...; do_compile=true
100111
)
@@ -162,7 +173,11 @@ struct DualInfo
162173
end
163174

164175
function generate_dual_ir(
165-
interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
176+
interp::MooncakeInterpreter,
177+
sig_or_mi;
178+
debug_mode=false,
179+
do_inline=true,
180+
prevent_primitive_inlining=false,
166181
)
167182
# Reset id count. This ensures that the IDs generated are the same each time this
168183
# function runs.
@@ -214,11 +229,15 @@ function generate_dual_ir(
214229
captures_tuple = (captures...,)
215230
dual_ir.argtypes[1] = _typeof(captures_tuple)
216231

217-
# Optimize dual IR using ForwardOverReverseInterpreter to prevent inlining primitives.
218-
# This is needed for forward-over-reverse: if primitives are inlined, foreigncalls
219-
# get exposed and we fail when trying to differentiate through them.
220-
fo_r_interp = ForwardOverReverseInterpreter(interp)
221-
dual_ir_opt = optimise_ir!(dual_ir; interp=fo_r_interp, do_inline)
232+
# Use ForwardOverReverseInterpreter when prevent_primitive_inlining is true.
233+
# This prevents primitives from being inlined, which is needed for forward-over-reverse
234+
# mode where inlining exposes foreigncalls that lack frule!!.
235+
if prevent_primitive_inlining
236+
fo_r_interp = ForwardOverReverseInterpreter(interp)
237+
dual_ir_opt = optimise_ir!(dual_ir; interp=fo_r_interp, do_inline)
238+
else
239+
dual_ir_opt = optimise_ir!(dual_ir; do_inline)
240+
end
222241
return dual_ir_opt, captures_tuple, DualRuleInfo(isva, nargs, dual_ret_type(primal_ir))
223242
end
224243

src/interpreter/reverse_mode.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,12 +1216,8 @@ function generate_ir(
12161216
rvs_ir = pullback_ir(
12171217
primal_ir, Treturn, ad_stmts_blocks, pb_comms_insts, info, _typeof(shared_data)
12181218
)
1219-
# Use ForwardOverReverseInterpreter to prevent inlining primitives.
1220-
# This is needed for forward-over-reverse: if primitives are inlined, foreigncalls
1221-
# get exposed and we fail when trying to differentiate through them.
1222-
fo_r_interp = ForwardOverReverseInterpreter(interp)
1223-
opt_fwd_ir = optimise_ir!(IRCode(fwd_ir); interp=fo_r_interp, do_inline)
1224-
opt_rvs_ir = optimise_ir!(IRCode(rvs_ir); interp=fo_r_interp, do_inline)
1219+
opt_fwd_ir = optimise_ir!(IRCode(fwd_ir); do_inline)
1220+
opt_rvs_ir = optimise_ir!(IRCode(rvs_ir); do_inline)
12251221
return DerivedRuleInfo(
12261222
ir, opt_fwd_ir, fwd_ret_type, opt_rvs_ir, rvs_ret_type, shared_data, info, isva
12271223
)

src/rrules/misty_closures.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,15 @@ end
2424
# bypassing method table lookups that would require a current-world interpreter
2525
# 3. Any nested non-primitive calls use LazyFRule/DynamicFRule which obtain a current-world
2626
# interpreter via get_interpreter() at runtime
27+
# We set prevent_primitive_inlining=true because MistyClosure is used in forward-over-reverse
28+
# mode. Without this, primitives get inlined during optimization, exposing foreigncalls that
29+
# lack frule!! definitions.
2730
function _dual_mc(p::MistyClosure)
2831
mc_world = UInt(p.oc.world)
2932
interp = MooncakeInterpreter(DefaultCtx, ForwardMode; world=mc_world)
30-
return build_frule(interp, p; skip_world_age_check=true)
33+
return build_frule(
34+
interp, p; skip_world_age_check=true, prevent_primitive_inlining=true
35+
)
3136
end
3237

3338
tangent_type(::Type{<:MistyClosure}) = MistyClosureTangent

0 commit comments

Comments
 (0)