Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fd8f884
Fix forward-over-reverse, add tests, DI doesn't work yet
sunxd3 Dec 6, 2025
1bb341d
fix version specific issue
sunxd3 Dec 7, 2025
4869aa9
Use NativeInterpreter in frule_type to avoid maybe_primitive recursion
sunxd3 Dec 7, 2025
d6480be
add frule for _build_rule!
sunxd3 Dec 7, 2025
9159ef9
revert the nativeinterpreter change
sunxd3 Dec 7, 2025
a426235
add has_equal_data_internal for MistyClosureTangent
sunxd3 Dec 7, 2025
231a696
add rule for jl_genericmemory_owner to fix 1.11
sunxd3 Dec 7, 2025
1d28b33
add frule for jl_alloc_array_1d to fix 1.10 error
sunxd3 Dec 7, 2025
4bb3c61
avoid hardcoded UInt length
sunxd3 Dec 7, 2025
64410c6
formatting
sunxd3 Dec 7, 2025
3f8d7eb
more x86 compat
sunxd3 Dec 8, 2025
029a777
deal with Union{} bottom type, make rule type Any for LazyFRule
sunxd3 Dec 8, 2025
03eb3c0
revert the Any type change for LazyFRule
sunxd3 Dec 8, 2025
14b6fcf
test forward over forward to see
sunxd3 Dec 8, 2025
4339183
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 8, 2025
5db1ce2
remove added DI tests
sunxd3 Dec 8, 2025
6578aa6
Update forward_over_reverse.jl
yebai Dec 9, 2025
5db16f0
Add skip_world_age_check kwarg to build_frule for MistyClosure support
sunxd3 Dec 9, 2025
09c496a
Add rrule!! for _build_rule! that throws on reverse-over-reverse
sunxd3 Dec 9, 2025
b7413d5
Remove unnecessary rules: literal_pow, push!, jl_genericmemory_owner
sunxd3 Dec 9, 2025
c2f14b0
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 12, 2025
2a9a680
Preserve primitive inlining policy in optimise_ir! for forward-over-r…
sunxd3 Dec 17, 2025
e18edfd
Restrict getfield frule!! to AbstractArray to avoid ambiguity with St…
sunxd3 Dec 17, 2025
e324759
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 17, 2025
3168a13
Fix Julia 1.12 Future assertion by using sv.interp for wrapper interp…
sunxd3 Dec 17, 2025
ed10d38
Revert interpreter forwarding changes (too hairy for now)
sunxd3 Dec 17, 2025
d79f375
Add back jl_genericmemory_owner frule/rrule for Julia 1.11+
sunxd3 Dec 17, 2025
c9109fe
Move `jl_genericmemory_owner` frule to test file.
sunxd3 Dec 17, 2025
7d258bf
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 17, 2025
f6ba807
version bump
sunxd3 Dec 17, 2025
d90977a
Merge branch 'main' into sunxd/f-o-r
yebai Dec 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions src/interpreter/forward_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
# Report whether `T` has `Union{}` (the bottom type) anywhere in its structure.
# Such types can appear for unreachable code or when type inference fails.
# The traversal itself lives in `_contains_bottom_type`; this entry point only supplies
# a fresh, empty set of already-visited types.
@inline function contains_bottom_type(T)
    visited = Base.IdSet{Any}()
    return _contains_bottom_type(T, visited)
end

# Recursive worker behind `contains_bottom_type`.
#
# `seen` records every `TypeVar`, `UnionAll` and `DataType` that has been visited so far.
# A type found in `seen` is reported as `false` instead of being walked a second time.
# `Union`s are never recorded; both of their branches are always inspected.
function _contains_bottom_type(T, seen::Base.IdSet{Any})
    T === Union{} && return true

    if T isa Union
        return _contains_bottom_type(T.a, seen) || _contains_bottom_type(T.b, seen)
    end

    # Anything that is not a type variable, `UnionAll` or `DataType` has no further
    # structure to inspect.
    T isa Union{TypeVar,UnionAll,DataType} || return false

    T in seen && return false
    push!(seen, T)

    # Only the upper bound of a type variable is inspected.
    T isa TypeVar && return _contains_bottom_type(T.ub, seen)
    T isa UnionAll && return _contains_bottom_type(T.body, seen)
    return any(p -> _contains_bottom_type(p, seen), T.parameters)
end

function build_frule(args...; debug_mode=false, silence_debug_messages=true)
sig = _typeof(TestUtils.__get_primals(args))
interp = get_interpreter(ForwardMode)
Expand All @@ -16,19 +44,27 @@ end
sig_or_mi;
debug_mode=false,
silence_debug_messages=true,
skip_world_age_check=false,
) where {C}

Returns a function which performs forward-mode AD for `sig_or_mi`. Will derive a rule if
`sig_or_mi` is not a primitive.

Set `skip_world_age_check=true` when the interpreter's world age is intentionally older
than the current world (e.g., when building rules for MistyClosure which uses its own world).
"""
function build_frule(
interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true
interp::MooncakeInterpreter{C},
sig_or_mi;
debug_mode=false,
silence_debug_messages=true,
skip_world_age_check=false,
) where {C}
@nospecialize sig_or_mi

# To avoid segfaults, ensure that we bail out if the interpreter's world age is behind
# (i.e. older than) the current world age.
if Base.get_world_counter() > interp.world
if !skip_world_age_check && Base.get_world_counter() > interp.world
throw(
ArgumentError(
"World age associated to interp is behind current world age. Please " *
Expand Down Expand Up @@ -331,7 +367,11 @@ function modify_fwd_ad_stmts!(
if isexpr(stmt, :invoke) || isexpr(stmt, :call)
raw_args = isexpr(stmt, :invoke) ? stmt.args[2:end] : stmt.args
sig_types = map(raw_args) do x
return CC.widenconst(get_forward_primal_type(info.primal_ir, x))
t = CC.widenconst(get_forward_primal_type(info.primal_ir, x))
# Replace types containing Union{} (unreachable code/failed inference)
# with Any. This allows the code to proceed; is_primitive will return
# false and we'll use dynamic rules that resolve types at runtime.
return contains_bottom_type(t) ? Any : t
end
sig = Tuple{sig_types...}
mi = isexpr(stmt, :invoke) ? get_mi(stmt.args[1]) : missing
Expand Down
82 changes: 82 additions & 0 deletions src/interpreter/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1893,6 +1893,88 @@ end
return rule.rule(args...)
end

# Forward-mode primitive for `_build_rule!` applied to a `LazyDerivedRule`.
# Making this a primitive avoids differentiating through `get_interpreter`, which has a
# ccall to `jl_get_world_counter`. Tangents are propagated through the call to the
# `fwds_oc` MistyClosure, not through the construction of the rule.
# Reverse-over-reverse is not supported; an `rrule!!` that throws is provided below.
@is_primitive MinimalCtx Tuple{typeof(_build_rule!),LazyDerivedRule,Tuple}

function frule!!(
    ::Dual{typeof(_build_rule!)},
    lazy_rule_dual::Dual{<:LazyDerivedRule{sig}},
    args_dual::Dual{<:Tuple},
) where {sig}
    lazy = primal(lazy_rule_dual)
    dlazy = tangent(lazy_rule_dual)
    xs = primal(args_dual)
    dxs = tangent(args_dual)

    # Constructing the reverse rule is a primal-only operation: it is built on first use
    # and cached on the lazy rule, with nothing to differentiate.
    if !isdefined(lazy, :rule)
        lazy.rule = build_rrule(
            get_interpreter(ReverseMode), lazy.mi; debug_mode=lazy.debug_mode
        )
    end
    rule = lazy.rule

    # Make sure the tangent of the lazy rule has an initialised entry for `rule`. When it
    # does not, install a zero tangent by replacing the `fields` of the mutable tangent.
    rule_slot = dlazy.fields.rule
    if !isdefined(rule_slot, :tangent)
        rule_slot = PossiblyUninitTangent(zero_tangent(rule))
        dlazy.fields = merge(dlazy.fields, (; rule=rule_slot))
    end
    drule = rule_slot.tangent

    # The forwards-pass closure of the derived rule and its tangent. Calling the derived
    # rule amounts to calling `fwds_oc` and pairing the result with a `Pullback`.
    fwds = rule.fwds_oc
    dfwds = drule.fields.fwds_oc

    # Regroup trailing arguments when the underlying method is varargs.
    isva = _isva(rule)
    nargs = rule.nargs
    N = length(xs)
    uf_xs = __unflatten_codual_varargs(isva, xs, nargs)
    uf_dxs = __unflatten_tangent_varargs(isva, dxs, nargs)

    # Pair each primal argument with its tangent and push them through `fwds_oc`.
    paired_args = map(Dual, uf_xs, uf_dxs)
    out_dual = frule!!(Dual(fwds, dfwds), paired_args...)

    # Build the `Pullback` and the tangent that accompanies it.
    pb = Pullback(sig, rule.pb_oc_ref, isva, N)
    dpb = Tangent((; pb_oc=drule.fields.pb_oc_ref))

    # The result is the dual of the tuple `(CoDual, Pullback)`.
    return Dual((primal(out_dual), pb), (tangent(out_dual), dpb))
end

# Tangent-side counterpart of `__unflatten_codual_varargs`.
# When `isva` is `false`, `tangent_args` is returned untouched. Otherwise the entries from
# position `nargs` onwards are collected into a single trailing group, preceded by the
# first `nargs - 1` entries.
function __unflatten_tangent_varargs(isva::Bool, tangent_args, ::Val{nargs}) where {nargs}
    if isva
        leading = tangent_args[1:(nargs - 1)]
        trailing = tangent_args[nargs:end]
        return (leading..., trailing)
    end
    return tangent_args
end

# Reverse-over-reverse is not supported: reaching `_build_rule!` while differentiating in
# reverse mode always raises an `ArgumentError` carrying an explanatory message.
function rrule!!(
    ::CoDual{typeof(_build_rule!)}, ::CoDual{<:LazyDerivedRule}, ::CoDual{<:Tuple}
)
    msg =
        "Reverse-over-reverse differentiation is not supported. " *
        "Encountered attempt to differentiate _build_rule! in reverse mode."
    throw(ArgumentError(msg))
end

"""
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Expand Down
79 changes: 79 additions & 0 deletions src/rrules/foreigncall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,85 @@ function rrule!!(
return uninit_fcodual(_foreigncall_(Val(:jl_string_ptr), x...)), pb!!
end

# Forward-mode rules for the `jl_alloc_array_{1,2,3}d` foreigncalls, defined only on
# Julia versions before 1.11. Each rule repeats the primal allocation and makes a second
# allocation with the same dimensions whose element type is `tangent_type(P)`.
@static if VERSION < v"1.11-"
    @generated function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_alloc_array_1d}},
        ::Dual{Val{Vector{P}}},
        ::Dual{Tuple{Val{Any},Val{Int}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        ::Dual{Type{Vector{P}}},
        n::Dual{Int},
        args::Vararg{Dual},
    ) where {P}
        # Resolve the tangent element type once, at generation time.
        dP = tangent_type(P)
        return quote
            len = primal(n)
            out = ccall(:jl_alloc_array_1d, Vector{$P}, (Any, Int), Vector{$P}, len)
            dout = ccall(:jl_alloc_array_1d, Vector{$dP}, (Any, Int), Vector{$dP}, len)
            return Dual(out, dout)
        end
    end

    @generated function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_alloc_array_2d}},
        ::Dual{Val{Matrix{P}}},
        ::Dual{Tuple{Val{Any},Val{Int},Val{Int}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        ::Dual{Type{Matrix{P}}},
        m::Dual{Int},
        n::Dual{Int},
        args::Vararg{Dual},
    ) where {P}
        # Resolve the tangent element type once, at generation time.
        dP = tangent_type(P)
        return quote
            rows = primal(m)
            cols = primal(n)
            out = ccall(
                :jl_alloc_array_2d, Matrix{$P}, (Any, Int, Int), Matrix{$P}, rows, cols
            )
            dout = ccall(
                :jl_alloc_array_2d, Matrix{$dP}, (Any, Int, Int), Matrix{$dP}, rows, cols
            )
            return Dual(out, dout)
        end
    end

    @generated function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_alloc_array_3d}},
        ::Dual{Val{Array{P,3}}},
        ::Dual{Tuple{Val{Any},Val{Int},Val{Int},Val{Int}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        ::Dual{Type{Array{P,3}}},
        l::Dual{Int},
        m::Dual{Int},
        n::Dual{Int},
        args::Vararg{Dual},
    ) where {P}
        # Resolve the tangent element type once, at generation time.
        dP = tangent_type(P)
        return quote
            d1 = primal(l)
            d2 = primal(m)
            d3 = primal(n)
            out = ccall(
                :jl_alloc_array_3d,
                Array{$P,3},
                (Any, Int, Int, Int),
                Array{$P,3},
                d1,
                d2,
                d3,
            )
            dout = ccall(
                :jl_alloc_array_3d,
                Array{$dP,3},
                (Any, Int, Int, Int),
                Array{$dP,3},
                d1,
                d2,
                d3,
            )
            return Dual(out, dout)
        end
    end
end

function unexpected_foreigncall_error(name)
throw(
error(
Expand Down
1 change: 1 addition & 0 deletions src/rrules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
@from_chainrules MinimalCtx Tuple{typeof(deg2rad),IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(rad2deg),IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(^),P,P} where {P<:IEEEFloat}

@from_chainrules MinimalCtx Tuple{typeof(atan),P,P} where {P<:IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(max),P,P} where {P<:IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(min),P,P} where {P<:IEEEFloat}
Expand Down
29 changes: 29 additions & 0 deletions src/rrules/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,24 @@ end

# _new_ and _new_-adjacent rules for Memory, MemoryRef, and Array.

# Rules for `Core.memorynew`, defined only on Julia 1.12 and later. Both modes allocate
# the primal `Memory{P}` and a `Memory{tangent_type(P)}` of the same length alongside it.
@static if VERSION >= v"1.12-"
    @is_primitive MinimalCtx Tuple{typeof(Core.memorynew),Type{<:Memory},Int}

    function frule!!(
        ::Dual{typeof(Core.memorynew)}, ::Dual{Type{Memory{P}}}, n::Dual{Int}
    ) where {P}
        mem = Core.memorynew(Memory{P}, primal(n))
        dmem = Core.memorynew(Memory{tangent_type(P)}, primal(n))
        return Dual(mem, dmem)
    end

    function rrule!!(
        ::CoDual{typeof(Core.memorynew)}, ::CoDual{Type{Memory{P}}}, n::CoDual{Int}
    ) where {P}
        mem = Core.memorynew(Memory{P}, primal(n))
        dmem = Core.memorynew(Memory{tangent_type(P)}, primal(n))
        out = CoDual(mem, dmem)
        # One `NoRData` per argument: the function, the type, and the length.
        return out, NoPullback((NoRData(), NoRData(), NoRData()))
    end
end

@is_primitive MinimalCtx Tuple{Type{<:Memory},UndefInitializer,Int}
function frule!!(::Dual{Type{Memory{P}}}, ::Dual{UndefInitializer}, n::Dual{Int}) where {P}
x = Memory{P}(undef, primal(n))
Expand Down Expand Up @@ -908,6 +926,17 @@ function hand_written_rule_test_cases(rng_ctor, ::Val{:memory})
zip(mem_refs, sample_mem_ref_values),
)
test_cases = vcat(
@static(
if VERSION >= v"1.12-"
[
(true, :stability, nothing, Core.memorynew, Memory{Float64}, 5),
(true, :stability, nothing, Core.memorynew, Memory{Float64}, 10),
(true, :stability, nothing, Core.memorynew, Memory{Int}, 5),
]
else
[]
end
),

# Rules for `Memory`
(true, :stability, nothing, Memory{Float64}, undef, 5),
Expand Down
29 changes: 27 additions & 2 deletions src/rrules/misty_closures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@ struct MistyClosureTangent
dual_callable::Any
end

_dual_mc(p::MistyClosure) = build_frule(get_interpreter(ForwardMode), p)
# Build a forward-mode rule for a MistyClosure at the world age stored in the closure
# (`p.oc.world`) rather than the current one. The world age check in `build_frule` is
# skipped because:
# 1. A MistyClosure captures its own IR when it is created (see lookup_ir in ir_utils.jl).
# 2. build_frule -> generate_dual_ir -> lookup_ir(interp, mc) returns mc.ir[] directly, so
#    no method table lookup needing a current-world interpreter takes place.
# 3. Nested non-primitive calls go through LazyFRule/DynamicFRule, which obtain a
#    current-world interpreter via get_interpreter() at runtime.
function _dual_mc(p::MistyClosure)
    closure_world = UInt(p.oc.world)
    closure_interp = MooncakeInterpreter(DefaultCtx, ForwardMode; world=closure_world)
    return build_frule(closure_interp, p; skip_world_age_check=true)
end

tangent_type(::Type{<:MistyClosure}) = MistyClosureTangent

Expand Down Expand Up @@ -61,13 +72,27 @@ function _scale_internal(c::MaybeCache, a::Float64, t::T) where {T<:MistyClosure
return T(captures_tangent, t.dual_callable)
end

import .TestUtils: populate_address_map_internal, AddressMap
import .TestUtils: populate_address_map_internal, AddressMap, has_equal_data_internal
# Populate the address map for a MistyClosure / MistyClosureTangent pair by delegating
# to the method for the closure's captures and the matching captures tangent.
function populate_address_map_internal(
    m::AddressMap, p::MistyClosure, t::MistyClosureTangent
)
    captures = p.oc.captures
    captures_tangent = t.captures_tangent
    return populate_address_map_internal(m, captures, captures_tangent)
end

function has_equal_data_internal(
    x::MistyClosureTangent,
    y::MistyClosureTangent,
    equal_undefs::Bool,
    d::Dict{Tuple{UInt,UInt},Bool},
)
    # Equality is decided by `captures_tangent` alone. `dual_callable` is a forward-mode
    # rule built on demand by `_dual_mc`, which constructs a new interpreter on every
    # call, so the rule objects differ even when they come from the same MistyClosure.
    # It is a computational tool rather than part of the tangent's value, so tangents
    # with identical `captures_tangent` are treated as mathematically equal.
    lhs = x.captures_tangent
    rhs = y.captures_tangent
    return has_equal_data_internal(lhs, rhs, equal_undefs, d)
end

struct MistyClosureFData
captures_fdata::Any
dual_callable::Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))
using DifferentiationInterface, DifferentiationInterfaceTest
using Mooncake: Mooncake

# Test first-order differentiation (reverse mode)
test_differentiation(
[AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())];
excluded=SECOND_ORDER,
Expand Down
5 changes: 4 additions & 1 deletion test/front_matter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ using Mooncake:
_add_to_primal,
_diff,
_dot,
Dual,
zero_dual,
zero_codual,
codual_type,
rrule!!,
build_rrule,
build_frule,
value_and_gradient!!,
value_and_pullback!!,
NoFData,
Expand All @@ -44,7 +46,8 @@ using Mooncake:
get_interpreter,
Mode,
ForwardMode,
ReverseMode
ReverseMode,
MistyClosureTangent

using Mooncake:
CC,
Expand Down
Loading
Loading