Skip to content

Commit 3b5cfa4

Browse files
willtebbuttyebai
andauthored
Fix inlining avoidance (#759)
* Failing test * Failing tests in contexts * Refactor primitive system * Improve contexts testing * Update low_level_maths not-a-primitive tests * Update tools_for_rules is_primitive calls * Formatting * Fix test * Fix doctests * Expand on maybe_primitive semantics * Hopefully improve performance * Remove additional complexity and fix docs build * Fix stuff * Formatting * Bump version * Improve documentation: * Reference to is_primitive replaced with _is_primitive * Uncomment precompilation help * Improve docstring * Improve comment at use of maybe_primitive * Update src/interpreter/contexts.jl Co-authored-by: Hong Ge <[email protected]> Signed-off-by: Will Tebbutt <[email protected]> * Remove redundant line spacing * Update src/interpreter/contexts.jl Co-authored-by: Hong Ge <[email protected]> Signed-off-by: Will Tebbutt <[email protected]> * Explain the world age mechanism in is_primitive * Improve docstring * Fix some docstrings * Docstring improved + comment added to implementation * MinimalCtx is no longer a subtype of DefaultCtx. * fix typos * bugfix: fix dispatch error * rm the `world` argument from `test_rule` * Update Project.toml Signed-off-by: Hong Ge <[email protected]> --------- Signed-off-by: Will Tebbutt <[email protected]> Signed-off-by: Hong Ge <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 832b3cb commit 3b5cfa4

File tree

16 files changed

+313
-87
lines changed

16 files changed

+313
-87
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.179"
4+
version = "0.4.180"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/developer_documentation/forwards_mode_design.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ Hand-written rules are implemented by writing methods of two functions: `is_prim
7070

7171
### `is_primitive`
7272

73-
`is_primitive(::Type{<:Union{MinimalForwardsCtx, DefaultForwardsCtx}}, signature::Type{<:Tuple})` should return `true` if AD must attempt to differentiate a call by passing the arguments to `frule!!`, and `false` otherwise.
74-
The [`Mooncake.@is_primitive`](@ref) macro helps makes implementing this very easy.
73+
`is_primitive(::Type{<:Union{MinimalForwardsCtx, DefaultForwardsCtx}}, signature::Type{<:Tuple}, world)` must return `true` if AD must attempt to differentiate a call by passing the arguments to `frule!!`, and `false` otherwise.
74+
The [`Mooncake.@is_primitive`](@ref) macro must be used to extend to create new primitives.
7575

7676
### `frule!!`
7777

docs/src/interface.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ See [Tutorial](@ref) for more info.
77

88
```@docs; canonical=true
99
Mooncake.Config
10+
Mooncake.value_and_derivative!!
1011
Mooncake.value_and_gradient!!(::Mooncake.Cache, f::F, x::Vararg{Any, N}) where {F, N}
1112
Mooncake.value_and_pullback!!(::Mooncake.Cache, ȳ, f::F, x::Vararg{Any, N}) where {F, N}
13+
Mooncake.prepare_derivative_cache
1214
Mooncake.prepare_gradient_cache
1315
Mooncake.prepare_pullback_cache
1416
```

src/Mooncake.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,12 @@ function rrule!! end
8080
build_primitive_rrule(sig::Type{<:Tuple})
8181
8282
Construct an rrule for signature `sig`. For this function to be called in `build_rrule`, you
83-
must also ensure that `is_primitive(context_type, ReverseMode, sig)` is `true`. The callable
84-
returned by this must obey the rrule interface, but there are no restrictions on the type of
85-
callable itself. For example, you might return a callable `struct`. By default, this
86-
function returns `rrule!!` so, most of the time, you should just implement a method of
87-
`rrule!!`.
83+
must also ensure that a method of `_is_primitive(context_type, ReverseMode, sig)` exists,
84+
preferably by using the [@is_primitive](@ref) macro.
85+
The callable returned by this must obey the rrule interface, but there are no restrictions
86+
on the type of callable itself. For example, you might return a callable `struct`. By
87+
default, this function returns `rrule!!` so, most of the time, you should just implement a
88+
method of `rrule!!`.
8889
8990
# Extended Help
9091

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ function value_and_gradient!!(
594594
end
595595

596596
"""
597-
prepare_derivative_cache(f, x...)
597+
prepare_derivative_cache(fx...; kwargs...)
598598
599599
Returns a cache used with [`value_and_derivative!!`](@ref). See that function for more info.
600600
"""

src/interpreter/abstract_interpretation.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,11 @@ function Core.Compiler.abstract_call_gf_by_type(
147147
sv::CC.AbsIntState,
148148
max_methods::Int,
149149
)
150-
is_primitive(C, M, atype) || return ret
150+
151+
# Check to see whether the call in question could possibly be a Mooncake primitive. If
152+
# it could be, set its call info such that it will not be inlined away.
153+
maybe_primitive(C, M, atype, interp.world) || return ret
154+
151155
# Insert a `NoInlineCallInfo` to prevent any potential inlining.
152156
@static if VERSION < v"1.12-"
153157
call = ret::CC.CallMeta

src/interpreter/contexts.jl

Lines changed: 178 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,31 +40,23 @@ function is a primitive in reverse-mode AD.
4040
struct ReverseMode <: Mode end
4141

4242
"""
43-
is_primitive(::Type{Ctx}, ::Type{M}, sig) where {Ctx,M}
43+
_is_primitive(context::Type, mode::Type{<:Mode}, sig::Type{<:Tuple})
4444
45-
Returns a `Bool` specifying whether the methods specified by `sig` are considered primitives
46-
in the context of contexts of type `Ctx` in mode `M`.
47-
48-
```julia
49-
is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64})
50-
```
51-
will return if calling `sin(5.0)` should be treated as primitive when the context is a
52-
`DefaultCtx`.
45+
This function is an internal implementation detail. It is used only by
46+
[`is_primitive`](@ref) and [`maybe_primitive`](@ref), and is used by these two functions in
47+
a very non-standard way. In particular, the value these functions return depends on the
48+
signatures of methods of this function, not what the methods do when invoked.
5349
54-
Observe that this information means that whether or not something is a primitive in a
55-
particular context depends only on static information, not any run-time information that
56-
might live in a particular instance of `Ctx`.
50+
Generally speaking, you ought not to add methods to this function
51+
yourself, but make use of [`@is_primitive`](@ref).
5752
"""
58-
is_primitive(::Type{MinimalCtx}, ::Type{<:Mode}, sig::Type{<:Tuple}) = false
59-
function is_primitive(::Type{DefaultCtx}, ::Type{M}, sig) where {M<:Mode}
60-
return is_primitive(MinimalCtx, M, sig)
61-
end
53+
function _is_primitive end
6254

6355
"""
6456
@is_primitive context_type [mode_type] signature
6557
66-
Creates a method of [`is_primitive`](@ref) which always returns `true` for the
67-
`context_type`, and `signature` provided. For example
58+
Declares that calls with signature `signature` are primitives in `context_type` and
59+
`mode_type`. For example
6860
```jldoctest
6961
julia> using Mooncake: DefaultCtx, @is_primitive, is_primitive, ForwardMode, ReverseMode
7062
@@ -73,16 +65,14 @@ foo (generic function with 1 method)
7365
7466
julia> @is_primitive DefaultCtx Tuple{typeof(foo),Float64}
7567
76-
julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(foo),Float64})
68+
julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(foo),Float64}, Base.get_world_counter())
7769
true
7870
79-
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64})
71+
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, Base.get_world_counter())
8072
true
8173
```
8274
Observe that this means that a rule is a primitive in all AD modes.
8375
84-
You should implement more complicated methods of [`is_primitive`](@ref) in the usual way.
85-
8676
Optionally, you can specify that a rule is only a primitive in a particular mode, eg.
8777
```jldoctest
8878
julia> using Mooncake: DefaultCtx, @is_primitive, is_primitive, ForwardMode, ReverseMode
@@ -92,10 +82,10 @@ bar (generic function with 1 method)
9282
9383
julia> @is_primitive DefaultCtx ForwardMode Tuple{typeof(bar),Float64}
9484
95-
julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(bar),Float64})
85+
julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(bar),Float64}, Base.get_world_counter())
9686
true
9787
98-
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(bar),Float64})
88+
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(bar),Float64}, Base.get_world_counter())
9989
false
10090
```
10191
"""
@@ -109,10 +99,173 @@ end
10999

110100
function _is_primitive_expression(Tctx, Tmode, sig)
111101
return quote
112-
function Mooncake.is_primitive(
102+
function Mooncake._is_primitive(
113103
::Type{$(esc(Tctx))}, ::Type{<:$(Tmode)}, ::Type{<:$(esc(sig))}
114104
)
115105
return true
116106
end
117107
end
118108
end
109+
110+
const _IS_PRIMITIVE_CACHE_DefaultCtx = IdDict{Any,Bool}()
111+
const _IS_PRIMITIVE_CACHE_MinimalCtx = IdDict{Any,Bool}()
112+
113+
"""
114+
is_primitive(ctx::Type, mode::Type{<:Mode}, sig::Type{<:Tuple}, world::UInt)
115+
116+
Returns a `Bool` specifying whether the methods specified by `sig` are considered primitives
117+
in the context of context `ctx` in mode `mode` at world age `world`.
118+
119+
```jldoctest
120+
julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode
121+
122+
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64}, Base.get_world_counter())
123+
true
124+
```
125+
126+
`world` is needed as rules which Mooncake derives are associated to a particular Julia world
127+
age. As a result, anything declared a primitive after the construction of a rule ought not
128+
to be considered a primitive by that rule. One can explicitly derive a new rule (eg. via
129+
[`build_frule`](@ref), [`build_rrule`](@ref), or a function from the higher-level interface
130+
such as [`prepare_derivative_cache`](@ref), [`prepare_pullback_cache`](@ref) or
131+
[`prepare_gradient_cache`](@ref)) after new `@is_primitive` declarations, should it be
132+
needed in cases where a rule has been previously derived. To see how this works, consider
133+
the following:
134+
```jldoctest
135+
julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode, @is_primitive
136+
137+
julia> foo(x::Float64) = 5x
138+
foo (generic function with 1 method)
139+
140+
julia> old_world_age = Base.get_world_counter();
141+
142+
julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo),Float64}
143+
144+
julia> new_world_age = Base.get_world_counter();
145+
146+
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, old_world_age)
147+
false
148+
149+
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, new_world_age)
150+
true
151+
```
152+
Observe that `is_primitive` returns `false` for the world age prior to declaring `foo` a
153+
primitive, but `true` afterwards. For more information on Julia's world age mechanism, see
154+
https://docs.julialang.org/en/v1/manual/methods/#Redefining-Methods .
155+
"""
156+
function is_primitive(
157+
ctx::Type{MinimalCtx}, mode::Type{<:Mode}, sig::Type{Tsig}, world::UInt
158+
) where {Tsig<:Tuple}
159+
@nospecialize sig
160+
161+
# We don't ever need to evaluate this function for abstract `mode`s, and there is a
162+
# performance penalty associated with doing so, so exclude the possibility.
163+
isconcretetype(mode) || throw(ArgumentError("mode $mode is not a concrete type."))
164+
165+
# Check to see whether any methods of `_is_primitive` exist which apply to this
166+
# ctx-mode-signature triple in world age `world`. If we have looked this up before,
167+
# return the answer from the cache.
168+
169+
tt = Tuple{typeof(_is_primitive),Type{ctx},Type{mode},Type{sig}}
170+
return get!(_IS_PRIMITIVE_CACHE_MinimalCtx, (world, tt)) do
171+
return !isempty(Base._methods_by_ftype(tt, -1, world))
172+
end
173+
end
174+
175+
function is_primitive(
176+
ctx::Type{DefaultCtx}, mode::Type{<:Mode}, sig::Type{Tsig}, world::UInt
177+
) where {Tsig<:Tuple}
178+
@nospecialize sig
179+
180+
isconcretetype(mode) || throw(ArgumentError("mode $mode is not a concrete type."))
181+
182+
# This function returns `true` if the method is a primitive in either
183+
# `DefaultCtx` _or_ `MinimalCtx`.
184+
tt = Tuple{typeof(_is_primitive),Type{DefaultCtx},Type{mode},Type{sig}}
185+
return get!(_IS_PRIMITIVE_CACHE_DefaultCtx, (world, tt)) do
186+
return is_primitive(MinimalCtx, mode, sig, world) ||
187+
!isempty(Base._methods_by_ftype(tt, -1, world))
188+
end
189+
end
190+
191+
const _MAYBE_PRIMITIVE_CACHE_MinimalCtx = IdDict{Any,Bool}()
192+
const _MAYBE_PRIMITIVE_CACHE_DefaultCtx = IdDict{Any,Bool}()
193+
194+
"""
195+
maybe_primitive(ctx::Type, mode::Type, sig::Type{<:Tuple}, world::UInt)
196+
197+
`true` if there exists `M<:mode`, and `S<:sig` such that
198+
`is_primitive(ctx, M, S, world)` returns `true`.
199+
200+
This functionality is used to determine whether or not it is safe to inline away a call
201+
site when performing abstract interpretation using a `MooncakeInterpreter`, which is only
202+
safe to do if the inferred argument types at the call site preclude the call being to a
203+
primitive.
204+
205+
For example, consider the following:
206+
```jldoctest is_prim_example
207+
julia> using Mooncake: Mooncake, @is_primitive, DefaultCtx, ReverseMode
208+
209+
julia> foo(x) = 5x;
210+
211+
julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo),Float64}
212+
213+
```
214+
This function agrees with [`is_primitive`](@ref) for fully inferred call sites:
215+
```jldoctest is_prim_example
216+
julia> world = Base.get_world_counter();
217+
218+
julia> Mooncake.maybe_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, world)
219+
true
220+
221+
julia> Mooncake.maybe_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Int}, world)
222+
false
223+
```
224+
However, it differs for call sites containing arguments whose types are not fully inferred.
225+
For example:
226+
```jldoctest is_prim_example
227+
julia> Mooncake.is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Real}, world)
228+
false
229+
230+
julia> Mooncake.maybe_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Real}, world)
231+
true
232+
```
233+
Per the definition at the top of this docstring, this function returns `true` because
234+
`Tuple{typeof(foo),Float64} <: Tuple{typeof(foo),Real}`.
235+
"""
236+
function maybe_primitive(
237+
ctx::Type{MinimalCtx}, mode::Type{<:Mode}, sig::Type{Tsig}, world::UInt
238+
) where {Tsig<:Tuple}
239+
@nospecialize sig
240+
241+
# We don't ever need to evaluate this function for abstract `mode`s, and there is a
242+
# performance penalty associated with doing so, so exclude the possibility.
243+
isconcretetype(mode) || throw(ArgumentError("mode $mode is not a concrete type."))
244+
245+
# Check to see whether any methods of `_is_primitive` exist which apply to any subtypes
246+
# of this ctx-mode-signature triple in world age `world`. If we have looked this up
247+
# before, return the answer from the cache.
248+
tt = Tuple{typeof(_is_primitive),Type{ctx},Type{mode},Type{<:sig}}
249+
return get!(_MAYBE_PRIMITIVE_CACHE_MinimalCtx, (world, tt)) do
250+
return !isempty(Base._methods_by_ftype(tt, -1, world))
251+
end
252+
end
253+
254+
function maybe_primitive(
255+
ctx::Type{DefaultCtx}, mode::Type{<:Mode}, sig::Type{Tsig}, world::UInt
256+
) where {Tsig<:Tuple}
257+
@nospecialize sig
258+
259+
# We don't ever need to evaluate this function for abstract `mode`s, and there is a
260+
# performance penalty associated with doing so, so exclude the possibility.
261+
isconcretetype(mode) || throw(ArgumentError("mode $mode is not a concrete type."))
262+
263+
# Check to see whether any methods of `_is_primitive` exist which apply to any subtypes
264+
# of this ctx-mode-signature triple in world age `world`. If we have looked this up
265+
# before, return the answer from the cache.
266+
tt = Tuple{typeof(_is_primitive),Type{ctx},Type{mode},Type{<:sig}}
267+
return get!(_MAYBE_PRIMITIVE_CACHE_DefaultCtx, (world, tt)) do
268+
return maybe_primitive(MinimalCtx, mode, sig, world) ||
269+
!isempty(Base._methods_by_ftype(tt, -1, world))
270+
end
271+
end

src/interpreter/forward_mode.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ struct DualRuleInfo
1010
dual_ret_type::Type
1111
end
1212

13+
"""
14+
build_frule(
15+
interp::MooncakeInterpreter{C},
16+
sig_or_mi;
17+
debug_mode=false,
18+
silence_debug_messages=true,
19+
) where {C}
20+
21+
Returns a function which performs forward-mode AD for `sig_or_mi`. Will derive a rule if
22+
`sig_or_mi` is not a primitive.
23+
"""
1324
function build_frule(
1425
interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true
1526
) where {C}
@@ -33,7 +44,9 @@ function build_frule(
3344

3445
# If we have a hand-coded rule, just use that.
3546
sig = _get_sig(sig_or_mi)
36-
is_primitive(C, ForwardMode, sig) && return (debug_mode ? DebugFRule(frule!!) : frule!!)
47+
if is_primitive(C, ForwardMode, sig, interp.world)
48+
return (debug_mode ? DebugFRule(frule!!) : frule!!)
49+
end
3750

3851
# We don't have a hand-coded rule, so derive one.
3952
lock(MOONCAKE_INFERENCE_LOCK)
@@ -334,7 +347,8 @@ function modify_fwd_ad_stmts!(
334347
return uninit_dual(get_const_primal_value(arg))
335348
end
336349

337-
if is_primitive(context_type(info.interp), ForwardMode, sig)
350+
interp = info.interp
351+
if is_primitive(context_type(interp), ForwardMode, sig, interp.world)
338352
replace_call!(dual_ir, ssa, Expr(:call, frule!!, dual_args...))
339353
else
340354
dm = info.debug_mode
@@ -423,7 +437,7 @@ function frule_type(
423437
interp::MooncakeInterpreter{C}, mi::CC.MethodInstance; debug_mode
424438
) where {C}
425439
primal_sig = _get_sig(mi)
426-
if is_primitive(C, ForwardMode, primal_sig)
440+
if is_primitive(C, ForwardMode, primal_sig, interp.world)
427441
return debug_mode ? DebugFRule{typeof(frule!!)} : typeof(frule!!)
428442
end
429443
ir, _ = lookup_ir(interp, mi)

src/interpreter/reverse_mode.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
702702

703703
# Construct signature, and determine how the rrule is to be computed.
704704
sig = Tuple{arg_types...}
705-
raw_rule = if is_primitive(context_type(info.interp), ReverseMode, sig)
705+
interp = info.interp
706+
raw_rule = if is_primitive(context_type(interp), ReverseMode, sig, interp.world)
706707
rrule!! # intrinsic / builtin / thing we provably have rule for
707708
elseif is_invoke
708709
mi = get_mi(stmt.args[1])
@@ -1123,7 +1124,7 @@ function build_rrule(
11231124

11241125
# If we have a hand-coded rule, just use that.
11251126
sig = _get_sig(sig_or_mi)
1126-
if is_primitive(C, ReverseMode, sig)
1127+
if is_primitive(C, ReverseMode, sig, interp.world)
11271128
rule = build_primitive_rrule(sig)
11281129
return (debug_mode ? DebugRRule(rule) : rule)
11291130
end
@@ -1900,7 +1901,7 @@ important for performance in dynamic dispatch, and to ensure that recursion work
19001901
properly.
19011902
"""
19021903
function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}
1903-
if is_primitive(C, ReverseMode, _get_sig(sig_or_mi))
1904+
if is_primitive(C, ReverseMode, _get_sig(sig_or_mi), interp.world)
19041905
rule = build_primitive_rrule(_get_sig(sig_or_mi))
19051906
return debug_mode ? DebugRRule{typeof(rule)} : typeof(rule)
19061907
end

0 commit comments

Comments
 (0)