|
29 | 29 | function build_frule(args...; debug_mode=false, silence_debug_messages=true) |
30 | 30 | sig = _typeof(TestUtils.__get_primals(args)) |
31 | 31 | interp = get_interpreter(ForwardMode) |
32 | | - return build_frule(interp, sig; debug_mode, silence_debug_messages) |
| 32 | + return build_frule( |
| 33 | + interp, sig; debug_mode, silence_debug_messages, prevent_primitive_inlining=false |
| 34 | + ) |
33 | 35 | end |
34 | 36 |
|
35 | 37 | struct DualRuleInfo |
|
45 | 47 | debug_mode=false, |
46 | 48 | silence_debug_messages=true, |
47 | 49 | skip_world_age_check=false, |
| 50 | + prevent_primitive_inlining=false, |
48 | 51 | ) where {C} |
49 | 52 |
|
50 | 53 | Returns a function which performs forward-mode AD for `sig_or_mi`. Will derive a rule if |
51 | 54 | `sig_or_mi` is not a primitive. |
52 | 55 |
|
53 | 56 | Set `skip_world_age_check=true` when the interpreter's world age is intentionally older |
54 | 57 | than the current world (e.g., when building rules for MistyClosure which uses its own world). |
| 58 | +
|
| 59 | +Set `prevent_primitive_inlining=true` for forward-over-reverse mode. This prevents primitives |
| 60 | +from being inlined during optimization, which would expose foreigncalls that lack frule!!. |
55 | 61 | """ |
56 | 62 | function build_frule( |
57 | 63 | interp::MooncakeInterpreter{C}, |
58 | 64 | sig_or_mi; |
59 | 65 | debug_mode=false, |
60 | 66 | silence_debug_messages=true, |
61 | 67 | skip_world_age_check=false, |
| 68 | + prevent_primitive_inlining=false, |
62 | 69 | ) where {C} |
63 | 70 | @nospecialize sig_or_mi |
64 | 71 |
|
@@ -89,12 +96,16 @@ function build_frule( |
89 | 96 | try |
90 | 97 | # If we've already derived the OpaqueClosures and info, do not re-derive, just |
91 | 98 | # create a copy and pass in new shared data. |
92 | | - oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode, :forward)) |
| 99 | + oc_cache_key = ClosureCacheKey( |
| 100 | + interp.world, (sig_or_mi, debug_mode, :forward, prevent_primitive_inlining) |
| 101 | + ) |
93 | 102 | if haskey(interp.oc_cache, oc_cache_key) |
94 | 103 | return interp.oc_cache[oc_cache_key] |
95 | 104 | else |
96 | 105 | # Derive forward-pass IR, and shove in a `MistyClosure`. |
97 | | - dual_ir, captures, info = generate_dual_ir(interp, sig_or_mi; debug_mode) |
| 106 | + dual_ir, captures, info = generate_dual_ir( |
| 107 | + interp, sig_or_mi; debug_mode, prevent_primitive_inlining |
| 108 | + ) |
98 | 109 | dual_oc = misty_closure( |
99 | 110 | info.dual_ret_type, dual_ir, captures...; do_compile=true |
100 | 111 | ) |
@@ -162,7 +173,11 @@ struct DualInfo |
162 | 173 | end |
163 | 174 |
|
164 | 175 | function generate_dual_ir( |
165 | | - interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true |
| 176 | + interp::MooncakeInterpreter, |
| 177 | + sig_or_mi; |
| 178 | + debug_mode=false, |
| 179 | + do_inline=true, |
| 180 | + prevent_primitive_inlining=false, |
166 | 181 | ) |
167 | 182 | # Reset id count. This ensures that the IDs generated are the same each time this |
168 | 183 | # function runs. |
@@ -214,11 +229,15 @@ function generate_dual_ir( |
214 | 229 | captures_tuple = (captures...,) |
215 | 230 | dual_ir.argtypes[1] = _typeof(captures_tuple) |
216 | 231 |
|
217 | | - # Optimize dual IR using ForwardOverReverseInterpreter to prevent inlining primitives. |
218 | | - # This is needed for forward-over-reverse: if primitives are inlined, foreigncalls |
219 | | - # get exposed and we fail when trying to differentiate through them. |
220 | | - fo_r_interp = ForwardOverReverseInterpreter(interp) |
221 | | - dual_ir_opt = optimise_ir!(dual_ir; interp=fo_r_interp, do_inline) |
| 232 | + # Use ForwardOverReverseInterpreter when prevent_primitive_inlining is true. |
| 233 | + # This prevents primitives from being inlined, which is needed for forward-over-reverse |
| 234 | + # mode where inlining exposes foreigncalls that lack frule!!. |
| 235 | + if prevent_primitive_inlining |
| 236 | + fo_r_interp = ForwardOverReverseInterpreter(interp) |
| 237 | + dual_ir_opt = optimise_ir!(dual_ir; interp=fo_r_interp, do_inline) |
| 238 | + else |
| 239 | + dual_ir_opt = optimise_ir!(dual_ir; do_inline) |
| 240 | + end |
222 | 241 | return dual_ir_opt, captures_tuple, DualRuleInfo(isva, nargs, dual_ret_type(primal_ir)) |
223 | 242 | end |
224 | 243 |
|
|
0 commit comments