@@ -40,31 +40,23 @@ function is a primitive in reverse-mode AD.
4040struct ReverseMode <: Mode end
4141
4242"""
43- is_primitive( ::Type{Ctx}, ::Type{M }, sig) where {Ctx,M}
43+ _is_primitive(context ::Type, mode ::Type{<:Mode }, sig::Type{<:Tuple})
4444
45- Returns a `Bool` specifying whether the methods specified by `sig` are considered primitives
46- in the context of contexts of type `Ctx` in mode `M`.
47-
48- ```julia
49- is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64})
50- ```
51- will return if calling `sin(5.0)` should be treated as primitive when the context is a
52- `DefaultCtx`.
45+ This function is an internal implementation detail. It is used only by
46+ [`is_primitive`](@ref) and [`maybe_primitive`](@ref), and is used by these two functions in
47+ a very non-standard way. In particular, the value these functions return depends on the
48+ signatures of methods of this function, not what the methods do when invoked.
5349
54- Observe that this information means that whether or not something is a primitive in a
55- particular context depends only on static information, not any run-time information that
56- might live in a particular instance of `Ctx`.
50+ Generally speaking, you ought not to add methods to this function
51+ yourself, but make use of [`@is_primitive`](@ref).
5752"""
58- is_primitive (:: Type{MinimalCtx} , :: Type{<:Mode} , sig:: Type{<:Tuple} ) = false
59- function is_primitive (:: Type{DefaultCtx} , :: Type{M} , sig) where {M<: Mode }
60- return is_primitive (MinimalCtx, M, sig)
61- end
53+ function _is_primitive end
6254
6355"""
6456 @is_primitive context_type [mode_type] signature
6557
66- Creates a method of [`is_primitive`](@ref) which always returns `true` for the
67- `context_type`, and `signature` provided . For example
58+ Declares that calls with signature `signature` are primitives in `context_type` and
59+ `mode_type` . For example
6860```jldoctest
6961julia> using Mooncake: DefaultCtx, @is_primitive, is_primitive, ForwardMode, ReverseMode
7062
@@ -73,16 +65,14 @@ foo (generic function with 1 method)
7365
7466julia> @is_primitive DefaultCtx Tuple{typeof(foo),Float64}
7567
76- julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(foo),Float64})
68+ julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(foo),Float64}, Base.get_world_counter() )
7769true
7870
79- julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64})
71+ julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, Base.get_world_counter() )
8072true
8173```
8274Observe that this means that a rule is a primitive in all AD modes.
8375
84- You should implement more complicated methods of [`is_primitive`](@ref) in the usual way.
85-
8676Optionally, you can specify that a rule is only a primitive in a particular mode, eg.
8777```jldoctest
8878julia> using Mooncake: DefaultCtx, @is_primitive, is_primitive, ForwardMode, ReverseMode
@@ -92,10 +82,10 @@ bar (generic function with 1 method)
9282
9383julia> @is_primitive DefaultCtx ForwardMode Tuple{typeof(bar),Float64}
9484
95- julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(bar),Float64})
85+ julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(bar),Float64}, Base.get_world_counter() )
9686true
9787
98- julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(bar),Float64})
88+ julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(bar),Float64}, Base.get_world_counter() )
9989false
10090```
10191"""
@@ -109,10 +99,173 @@ end
10999
110100function _is_primitive_expression (Tctx, Tmode, sig)
111101 return quote
112- function Mooncake. is_primitive (
102+ function Mooncake. _is_primitive (
113103 :: Type{$(esc(Tctx))} , :: Type{<:$(Tmode)} , :: Type{<:$(esc(sig))}
114104 )
115105 return true
116106 end
117107 end
118108end
109+
110+ const _IS_PRIMITIVE_CACHE_DefaultCtx = IdDict {Any,Bool} ()
111+ const _IS_PRIMITIVE_CACHE_MinimalCtx = IdDict {Any,Bool} ()
112+
113+ """
114+ is_primitive(ctx::Type, mode::Type{<:Mode}, sig::Type{<:Tuple}, world::UInt)
115+
116+ Returns a `Bool` specifying whether the methods specified by `sig` are considered primitives
117+ in the context of context `ctx` in mode `mode` at world age `world`.
118+
119+ ```jldoctest
120+ julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode
121+
122+ julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64}, Base.get_world_counter())
123+ true
124+ ```
125+
126+ `world` is needed as rules which Mooncake derives are associated to a particular Julia world
127+ age. As a result, anything declared a primitive after the construction of a rule ought not
128+ to be considered a primitive by that rule. One can explicitly derive a new rule (eg. via
129+ [`build_frule`](@ref), [`build_rrule`](@ref), or a function from the higher-level interface
130+ such as [`prepare_derivative_cache`](@ref), [`prepare_pullback_cache`](@ref) or
131+ [`prepare_gradient_cache`](@ref)) after new `@is_primitive` declarations, should it be
132+ needed in cases where a rule has been previously derived. To see how this works, consider
133+ the following:
134+ ```jldoctest
135+ julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode, @is_primitive
136+
137+ julia> foo(x::Float64) = 5x
138+ foo (generic function with 1 method)
139+
140+ julia> old_world_age = Base.get_world_counter();
141+
142+ julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo),Float64}
143+
144+ julia> new_world_age = Base.get_world_counter();
145+
146+ julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, old_world_age)
147+ false
148+
149+ julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, new_world_age)
150+ true
151+ ```
152+ Observe that `is_primitive` returns `false` for the world age prior to declaring `foo` a
153+ primitive, but `true` afterwards. For more information on Julia's world age mechanism, see
154+ https://docs.julialang.org/en/v1/manual/methods/#Redefining-Methods .
155+ """
156+ function is_primitive (
157+ ctx:: Type{MinimalCtx} , mode:: Type{<:Mode} , sig:: Type{Tsig} , world:: UInt
158+ ) where {Tsig<: Tuple }
159+ @nospecialize sig
160+
161+ # We don't ever need to evaluate this function for abstract `mode`s, and there is a
162+ # performance penalty associated with doing so, so exclude the possibility.
163+ isconcretetype (mode) || throw (ArgumentError (" mode $mode is not a concrete type." ))
164+
165+ # Check to see whether any methods of `_is_primitive` exist which apply to this
166+ # ctx-mode-signature triple in world age `world`. If we have looked this up before,
167+ # return the answer from the cache.
168+
169+ tt = Tuple{typeof (_is_primitive),Type{ctx},Type{mode},Type{sig}}
170+ return get! (_IS_PRIMITIVE_CACHE_MinimalCtx, (world, tt)) do
171+ return ! isempty (Base. _methods_by_ftype (tt, - 1 , world))
172+ end
173+ end
174+
175+ function is_primitive (
176+ ctx:: Type{DefaultCtx} , mode:: Type{<:Mode} , sig:: Type{Tsig} , world:: UInt
177+ ) where {Tsig<: Tuple }
178+ @nospecialize sig
179+
180+ isconcretetype (mode) || throw (ArgumentError (" mode $mode is not a concrete type." ))
181+
182+ # This function returns `true` if the method is a primitive in either
183+ # `DefaultCtx` _or_ `MinimalCtx`.
184+ tt = Tuple{typeof (_is_primitive),Type{DefaultCtx},Type{mode},Type{sig}}
185+ return get! (_IS_PRIMITIVE_CACHE_DefaultCtx, (world, tt)) do
186+ return is_primitive (MinimalCtx, mode, sig, world) ||
187+ ! isempty (Base. _methods_by_ftype (tt, - 1 , world))
188+ end
189+ end
190+
191+ const _MAYBE_PRIMITIVE_CACHE_MinimalCtx = IdDict {Any,Bool} ()
192+ const _MAYBE_PRIMITIVE_CACHE_DefaultCtx = IdDict {Any,Bool} ()
193+
194+ """
195+ maybe_primitive(ctx::Type, mode::Type, sig::Type{<:Tuple}, world::UInt)
196+
197+ `true` if there exists `M<:mode`, and `S<:sig` such that
198+ `is_primitive(ctx, M, S, world)` returns `true`.
199+
200+ This functionality is used to determine whether or not it is safe to inline away a call
201+ site when performing abstract interpretation using a `MooncakeInterpreter`, which is only
202+ safe to do if the inferred argument types at the call site preclude the call being to a
203+ primitive.
204+
205+ For example, consider the following:
206+ ```jldoctest is_prim_example
207+ julia> using Mooncake: Mooncake, @is_primitive, DefaultCtx, ReverseMode
208+
209+ julia> foo(x) = 5x;
210+
211+ julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo),Float64}
212+
213+ ```
214+ This function agrees with [`is_primitive`](@ref) for fully inferred call sites:
215+ ```jldoctest is_prim_example
216+ julia> world = Base.get_world_counter();
217+
218+ julia> Mooncake.maybe_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, world)
219+ true
220+
221+ julia> Mooncake.maybe_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Int}, world)
222+ false
223+ ```
224+ However, it differs for call sites containing arguments whose types are not fully inferred.
225+ For example:
226+ ```jldoctest is_prim_example
227+ julia> Mooncake.is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Real}, world)
228+ false
229+
230+ julia> Mooncake.maybe_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Real}, world)
231+ true
232+ ```
233+ Per the definition at the top of this docstring, this function returns `true` because
234+ `Tuple{typeof(foo),Float64} <: Tuple{typeof(foo),Real}`.
235+ """
236+ function maybe_primitive (
237+ ctx:: Type{MinimalCtx} , mode:: Type{<:Mode} , sig:: Type{Tsig} , world:: UInt
238+ ) where {Tsig<: Tuple }
239+ @nospecialize sig
240+
241+ # We don't ever need to evaluate this function for abstract `mode`s, and there is a
242+ # performance penalty associated with doing so, so exclude the possibility.
243+ isconcretetype (mode) || throw (ArgumentError (" mode $mode is not a concrete type." ))
244+
245+ # Check to see whether any methods of `_is_primitive` exist which apply to any subtypes
246+ # of this ctx-mode-signature triple in world age `world`. If we have looked this up
247+ # before, return the answer from the cache.
248+ tt = Tuple{typeof (_is_primitive),Type{ctx},Type{mode},Type{<: sig }}
249+ return get! (_MAYBE_PRIMITIVE_CACHE_MinimalCtx, (world, tt)) do
250+ return ! isempty (Base. _methods_by_ftype (tt, - 1 , world))
251+ end
252+ end
253+
254+ function maybe_primitive (
255+ ctx:: Type{DefaultCtx} , mode:: Type{<:Mode} , sig:: Type{Tsig} , world:: UInt
256+ ) where {Tsig<: Tuple }
257+ @nospecialize sig
258+
259+ # We don't ever need to evaluate this function for abstract `mode`s, and there is a
260+ # performance penalty associated with doing so, so exclude the possibility.
261+ isconcretetype (mode) || throw (ArgumentError (" mode $mode is not a concrete type." ))
262+
263+ # Check to see whether any methods of `_is_primitive` exist which apply to any subtypes
264+ # of this ctx-mode-signature triple in world age `world`. If we have looked this up
265+ # before, return the answer from the cache.
266+ tt = Tuple{typeof (_is_primitive),Type{ctx},Type{mode},Type{<: sig }}
267+ return get! (_MAYBE_PRIMITIVE_CACHE_DefaultCtx, (world, tt)) do
268+ return maybe_primitive (MinimalCtx, mode, sig, world) ||
269+ ! isempty (Base. _methods_by_ftype (tt, - 1 , world))
270+ end
271+ end
0 commit comments