Conversation

@sunxd3
Collaborator

@sunxd3 sunxd3 commented Dec 6, 2025

Ref: #826

This PR aims to enable forward-over-reverse for second-order derivatives. Much of the necessary plumbing was already implemented before this PR; this PR smooths out the remaining rough edges.
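
To make the forward-over-reverse idea concrete, here is a minimal, self-contained sketch that does not use Mooncake at all. The names `Dual`, `grad_f`, and `hvp` are hypothetical and introduced only for illustration, and `grad_f` is a hand-written stand-in for the gradient that reverse mode would produce. Pushing a tangent through the gradient computation yields a Hessian-vector product:

```julia
# Minimal dual number: a value together with a tangent (directional derivative).
struct Dual
    val::Float64
    tan::Float64
end

# Product rule for dual numbers, plus scaling by a plain real constant.
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.val * b.tan + a.tan * b.val)
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.tan)

# Hand-written gradient of f(x) = x[1]^2 * x[2].
# Assumption: this stands in for what a reverse-mode pass would compute.
grad_f(x) = [2 * x[1] * x[2], x[1] * x[1]]

# Forward-over-reverse: seed the inputs with direction v, run the gradient
# code on dual numbers, and read the Hessian-vector product off the tangents.
function hvp(x, v)
    duals = [Dual(xi, vi) for (xi, vi) in zip(x, v)]
    return [d.tan for d in grad_f(duals)]
end

println(hvp([1.0, 2.0], [1.0, 0.0]))  # first column of the Hessian at x = [1, 2]
```

Calling `hvp` once per basis vector would recover the Hessian column by column, which is roughly the structure of a forward-over-reverse Hessian.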

Fixes:

  1. World age fix for MistyClosure: Forward mode now uses MistyClosure's original world age instead of current world, fixing failures when zero_tangent(pb_mc) was called.
  2. Union{} type issue (somewhat specific to 1.10): Union{} can show up when type inference fails; the approach here is to replace it with Any.
  3. Add some necessary rules to avoid hanging when testing some particular functions:
    1. Core.memorynew
    2. LazyDerivedRule: avoid diff through get_interpreter (jl_get_world_counter)
    3. jl_genericmemory_owner for 1.11
    4. jl_alloc_array_1d/2d/3d

What doesn't work yet (hangs)

  1. DI.hessian with SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))

Closes #632

@sunxd3 sunxd3 mentioned this pull request Dec 6, 2025
@github-actions
Contributor

github-actions bot commented Dec 7, 2025

Mooncake.jl documentation for PR #878 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR878/

@github-actions
Contributor

github-actions bot commented Dec 7, 2025

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                      Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                     String │   String │   String │      String │  String │      String │ String │
├────────────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│                   sum_1000 │ 100.0 ns │      1.9 │         1.9 │     1.0 │        5.61 │   8.31 │
│                  _sum_1000 │ 941.0 ns │     6.57 │        1.01 │  1470.0 │        33.7 │   1.09 │
│               sum_sin_1000 │  6.56 μs │     2.44 │        1.38 │    1.67 │        10.4 │   2.18 │
│              _sum_sin_1000 │  5.23 μs │     2.99 │         2.2 │   249.0 │        13.4 │   2.49 │
│                   kron_sum │ 317.0 μs │     41.5 │        2.96 │    5.68 │       205.0 │   9.78 │
│              kron_view_sum │ 344.0 μs │     40.7 │        3.38 │    11.1 │       221.0 │   6.54 │
│      naive_map_sin_cos_exp │  2.15 μs │     2.38 │         1.4 │ missing │        7.12 │   2.34 │
│            map_sin_cos_exp │  2.12 μs │     2.72 │        1.45 │    1.51 │        6.18 │   2.93 │
│      broadcast_sin_cos_exp │  2.27 μs │      2.4 │        1.37 │    2.32 │        1.47 │   2.25 │
│                 simple_mlp │ 389.0 μs │     4.72 │         3.1 │    1.65 │        8.61 │    3.3 │
│                     gp_lml │ 248.0 μs │     8.36 │        2.08 │    3.68 │     missing │   6.55 │
│ turing_broadcast_benchmark │  1.75 ms │     4.94 │        3.48 │ missing │        27.3 │   2.31 │
│         large_single_block │ 380.0 ns │     4.53 │        2.03 │  4410.0 │        31.2 │   2.24 │
└────────────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

@Technici4n Technici4n self-requested a review December 7, 2025 10:37
@sunxd3 sunxd3 requested a review from yebai December 8, 2025 11:15
@sunxd3
Collaborator Author

sunxd3 commented Dec 8, 2025

@yebai @Technici4n this is ready for a review.

Integration with DI is not complete. I have some guesses as to why, but I am not certain yet. The PR is probably big enough already.

One thing I don't know how to resolve: examples like x -> sum(x .* x) will hang on 1.10. Functions like these use DerivedFRule, and frule_type hangs on 1.10. A hot fix is something like 03eb3c0; it works, but creates allocations.
Correction: x -> sum(x .* x) in fact works fine on 1.10; I mistook "taking a bit long" for "hangs".

@yebai
Member

yebai commented Dec 8, 2025

One thing I don't know how to resolve: examples like x -> sum(x .* x) will hang on 1.10. Function like these uses DerivedFRule and frule_type on 1.10 hangs.

This might be an actual Julia compiler bug. Can you try to create a MWE so we can report and patch over it?

Integration with DI is not complete, I have some guess of reasons, but not certain yet. The PR is probably big enough.

It's probably okay to keep fixing DI-related issues in this PR if they are not major. I'm also happy for you to start a separate PR otherwise.

@gdalle
Collaborator

gdalle commented Dec 8, 2025

Very cool that you found a way to do this without touching DI internals @sunxd3, congrats and thanks for the effort!!!
Let me know when this is in a good enough state, I'll put it through some more extensive DI testing

@sunxd3
Collaborator Author

sunxd3 commented Dec 8, 2025

@yebai

This might be an actual Julia compiler bug.

It might be, but whatever it is, it's fixed in 1.11. Curious what it is, though.
The tests pass, but are taking quite a long while (https://github.com/chalk-lab/Mooncake.jl/actions/runs/20024598254/job/57418981172?pr=878).

My guess is that type inference failed and produced Union{}, which was then handled by the Union{} fix, so the test didn't fail. See #878 (comment)

I would prefer to stop here and fix DI integration in another PR.

@sunxd3
Collaborator Author

sunxd3 commented Dec 8, 2025

Thanks @gdalle, unfortunately, the DI test doesn't work yet.

I wanted to keep it simple and use Mooncake internal functions to make a minimal forward-over-reverse work.

I might start DI integration in another PR. I don't think we would need an extension, but it definitely needs more investigation.

@gdalle
Collaborator

gdalle commented Dec 8, 2025

I agree, 2nd order should work inside of Mooncake instead of requiring DI hacks. I just meant that DI has more possible tests to run than what is currently inside your integration tests

Collaborator

@Technici4n Technici4n left a comment

This is really cool! I will have a deeper look later this week. As always, I am a bit suspicious of new rules that need to be added. 😉

# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
# than the current world age. Exception: MistyClosure uses its own world age.
Collaborator

We could maybe add a kwarg to disable the check? Special-casing MC here is a bit strange.

@zero_derivative MinimalCtx Tuple{typeof(Base.has_free_typevars),Any}

# jl_genericmemory_owner is used by Base.dataids to determine memory aliasing.
# The result is just a pointer for aliasing detection, so it has zero derivative.
Collaborator

Why not use @zero_derivative?

Collaborator Author

My understanding is that @zero_derivative needs the function name as the first argument. _foreigncall_ is a special case because the first argument is _foreigncall_, not the ccall name.

Collaborator Author

@sunxd3 sunxd3 Dec 9, 2025

https://github.com/chalk-lab/Mooncake.jl/actions/runs/20058347212/job/57528828789?pr=878 Interestingly, this only fails on 1.11 (I removed the rule temporarily to see where it traps).

Collaborator Author

@sunxd3 sunxd3 Dec 11, 2025

I did a bit more digging. This turns out to be a pretty interesting hidden issue.

The question to ask here is why

@static if VERSION >= v"1.11-"
@zero_derivative MinimalCtx Tuple{typeof(Random.hash_seed),Vararg}
@zero_derivative MinimalCtx Tuple{typeof(Base.dataids),Memory}
end
didn't prevent the error.

The answer is that the call is inlined away, and that's why jl_genericmemory_owner is exposed (1.12 doesn't use jl_genericmemory_owner anymore, so this test doesn't error on 1.12).

To verify, with julia +1.11

using Mooncake
using Mooncake: CC, get_interpreter, BugPatchInterpreter, lookup_ir, optimise_ir!, ForwardMode

f(x::Vector{Float64}) = Base.dataids(x.ref.mem)

interp = get_interpreter(ForwardMode)

# IR from lookup_ir - dataids is NOT inlined (MooncakeInterpreter prevents it)
ir, _ = lookup_ir(interp, Tuple{typeof(f), Vector{Float64}})

# IR after optimise_ir! - dataids IS inlined, exposing jl_genericmemory_owner foreigncall
ir_opt = optimise_ir!(CC.copy(ir); do_inline=true)

gives

julia> ir  # dataids is NOT inlined
1 ─ %1 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│   %2 = Base.getfield(%1, :mem)::Memory{Float64}
│   %3 = invoke Base.dataids(%2::Memory{Float64})::Tuple{UInt64}   # <-- stays as invoke
└──      return %3

julia> ir_opt  # dataids IS inlined, exposing foreigncall
1 ─ %1  = Base.getfield(_2, :ref)::MemoryRef{Float64}
│   %2  = Base.getfield(%1, :mem)::Memory{Float64}
│   %3  = foreigncall(:jl_genericmemory_owner, ...)::Any           # <-- inlined!
│   %4  = (%3 isa Memory{Float64})::Bool
└──       goto #3 if not %4
2 ─ %6  = π (%3, Memory{Float64})
└──       goto #4
3 ─       nothing
4 ┄ %9  = φ (#2 => %6, #3 => %2)::Memory{Float64}
│   %10 = Base.getfield(%9, :ptr)::Ptr{Nothing}
│   %11 = Base.bitcast(Ptr{Float64}, %10)::Ptr{Float64}
│   %12 = Core.bitcast(Core.UInt, %11)::UInt64
│   %13 = Core.tuple(%12)::Tuple{UInt64}
└──       return %13

Why This Only Affects Forward-over-Reverse

When building rules, lookup_ir uses MooncakeInterpreter which prevents inlining primitives. But then optimise_ir! uses BugPatchInterpreter which has no such protection—it inlines everything including rrule!!(dataids, ...) down to the raw foreigncall.

In pure reverse mode, rule(...) executes this optimized IR at runtime. The foreigncall just runs and returns a value.

In forward-over-reverse, build_frule calls lookup_ir(MistyClosure) which returns the already-optimized IR containing the foreigncall. Now we try to differentiate through this IR, which requires an frule!! for the foreigncall—but none exists.
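
The way inlining exposes a raw foreigncall can be reproduced in plain Julia, without Mooncake. The sketch below is an illustration under assumptions: `world_counter`, `bump`, and `count_foreigncalls` are hypothetical helpers, and it relies on Base's `code_typed` rather than Mooncake's `lookup_ir` or `optimise_ir!`. On a default Julia configuration I would expect the unoptimised IR of `bump` to contain only a call to `world_counter`, and the optimised IR to contain the inlined foreigncall.

```julia
# Hypothetical helper wrapping a ccall (the same C function mentioned in this PR).
world_counter() = ccall(:jl_get_world_counter, UInt, ())

# A caller that only sees the ccall once world_counter has been inlined.
bump() = world_counter() + 1

# Count statements in the typed IR of `f` that are raw :foreigncall expressions.
function count_foreigncalls(f, types; optimize::Bool)
    ci, _ = only(code_typed(f, types; optimize=optimize))
    return count(stmt -> stmt isa Expr && stmt.head === :foreigncall, ci.code)
end

n_before = count_foreigncalls(bump, (); optimize=false)
n_after = count_foreigncalls(bump, (); optimize=true)
println("foreigncalls before optimisation: ", n_before)
println("foreigncalls after optimisation:  ", n_after)
```

Any tool that differentiates the optimised IR then has to handle the foreigncall itself, which is the situation forward mode runs into here.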

# Forward-mode primitive for _build_rule! on LazyDerivedRule.
# This avoids differentiating through get_interpreter which has a ccall to jl_get_world_counter.
# The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building.
# Only primitive in ForwardMode - reverse mode uses derived rule.
Collaborator

Is said derived rule working? If not it's likely better to define a rule that throws rather than letting Mooncake try to differentiate through it and fail.

Collaborator Author

Added an rrule stating that we don't support reverse-over-reverse yet.

@codecov

codecov bot commented Dec 8, 2025

Codecov Report

❌ Patch coverage is 79.00000% with 21 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/rrules/foreigncall.jl 28.57% 15 Missing ⚠️
src/interpreter/reverse_mode.jl 89.18% 4 Missing ⚠️
src/rrules/memory.jl 83.33% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

@yebai
Member

yebai commented Dec 8, 2025

This might be an actual Julia compiler bug.

It might be, but whatever it is, it's fixed in 1.11. Curious what it is, though.
The tests pass, but is taking quite a long while (https://github.com/chalk-lab/Mooncake.jl/actions/runs/20024598254/job/57418981172?pr=878). My guess is that type inference failed, produce Union{}, but handled by the Union{} fix, so the test didn't fail.

@sunxd3 That’s possible. For the sake of thoroughness, let’s get to the bottom of this. For example, if your guess above is correct, we ought to see Union{} returned by frule_type for the MWE that’s causing issues.

@sunxd3
Collaborator Author

sunxd3 commented Dec 9, 2025

On the Tuple field type cannot be Union{} error:

With the following code example,

using Core: ReturnNode
using Core.Compiler
const CC = Core.Compiler
using Mooncake: build_rrule, lookup_ir, get_interpreter, ReverseMode

f(x) = sum(x .* x)
x = [1.0, 2.0]

# IR of the primal function, as seen by the reverse-mode interpreter.
tt = Tuple{typeof(f),typeof(x)}
interp = get_interpreter(ReverseMode)
julia_ir, _ = lookup_ir(interp, tt)

# Build the reverse-mode rule, then look up the IR of its forwards-pass closure.
rule = build_rrule(f, x)
interp = get_interpreter(ReverseMode)
mooncake_ir, _ = lookup_ir(interp, rule.fwds_oc)

# Print every statement whose inferred type is Union{}.
for i in 1:length(mooncake_ir.stmts)
    mooncake_ir.stmts[i][:type] === Union{} || continue
    # The statement field is named :stmt on 1.11+ and :inst on 1.10.
    stmt = @static if VERSION >= v"1.11-"
        mooncake_ir.stmts[i][:stmt]
    else
        mooncake_ir.stmts[i][:inst]
    end
    typ = mooncake_ir.stmts[i][:type]
    println("SSA $i :: $typ = $(typeof(stmt).name.name) - $stmt")
end

On Julia 1.10, this returns:
SSA 124 :: Union{} = GotoIfNot - goto %44 if not %123
SSA 127 :: Union{} = GotoNode - goto %47
SSA 129 :: Union{} = GotoIfNot - goto %46 if not %128
SSA 132 :: Union{} = GotoNode - goto %47
SSA 134 :: Union{} = GotoNode - goto %47
SSA 144 :: Union{} = Expr - (Mooncake.rrule!!)(...)
SSA 477 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 504 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 544 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 600 :: Union{} = Expr - invoke MistyClosure(throwdm)(...)
SSA 602 :: Union{} = PhiNode - φ ()                          # ← ORPHANED!
SSA 604 :: Union{} = Expr - invoke Mooncake._build_rule!(...)
SSA 606 :: Union{} = PhiNode - φ ()                          # ← ORPHANED!
SSA 607 :: Union{} = Expr - (getfield)(%606, 1)              # ← USES Union{} phi!
SSA 608 :: Union{} = Expr - (getfield)(%606, 2)              # ← USES Union{} phi!
SSA 609 :: Union{} = Expr - (tuple)(%608)
SSA 610 :: Union{} = Expr - (push!)(%9, %609)

On Julia 1.12, this returns:
SSA 45 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 46 :: Union{} = ReturnNode - unreachable     # ← CLEAN!
SSA 79 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 80 :: Union{} = ReturnNode - unreachable     # ← CLEAN!
# ... (pattern continues: throw call → unreachable, no orphaned phi nodes)
SSA 969 :: Union{} = Expr - invoke MistyClosure(throwdm)(...)
SSA 970 :: Union{} = ReturnNode - unreachable    # ← CLEAN!
SSA 971 :: Union{} = Expr - invoke Mooncake._build_rule!(...)
SSA 972 :: Union{} = ReturnNode - unreachable    # ← CLEAN!
# ... (61 total Union{} statements, all are Expr or ReturnNode)

The Culprit

The orphaned φ () (empty PhiNode) has Union{} type. Operations that use it also get Union{} type:

SSA 606 :: Union{} = PhiNode - φ ()              # ← orphaned empty phi
SSA 607 :: Union{} = Expr - (getfield)(%606, 1)  # ← uses the Union{} phi

When Mooncake processes getfield(%606, 1), it builds:

sig_types = [type_of(%606), type_of(1)]  # = [Union{}, Int]
sig = Tuple{Union{}, Int}                 # ← ERROR on Julia 1.10+!

I am still trying to nail down which Julia compiler PRs fix this. The fix was probably applied to IncrementalCompact, most likely in JuliaLang/julia#52614 and JuliaLang/julia#52863.

Why Other Union{} Statements Are Fine

Mooncake handles other Union{}-typed statements gracefully:

  • Expr (throw calls): Legitimate "never returns" calls - Mooncake has rules for these
  • ReturnNode - unreachable: Mooncake skips these (isdefined(stmt, :val) || return nothing)

Only the orphaned PhiNode and operations using it cause problems, because Mooncake tries to build type signatures containing Union{}.

The fix currently in this PR is fine IMO. I would consider this a Julia compiler bug, but it was fixed in the 1.11 release.
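
For illustration, the "replace Union{} with Any" idea can be sketched in plain Julia. This is not Mooncake's actual implementation: `widen_bottom` and `safe_sig` are hypothetical names, and the sketch only shows how a signature can be assembled without ever placing Union{} in a Tuple field (which, as noted above, errors on Julia 1.10+).

```julia
# Hypothetical helper: widen the bottom type to Any, leave every other type alone.
widen_bottom(T) = T === Union{} ? Any : T

# Build a Tuple signature from a list of argument types, widening Union{} first
# so that the Tuple type constructor never sees a Union{} field.
function safe_sig(arg_types)
    return Tuple{map(widen_bottom, arg_types)...}
end

# Argument types as they might be read off `getfield(%606, 1)` in the IR above.
sig = safe_sig(Any[Union{}, Int])
println(sig)
```

Widening to Any loses type information, so a rule built from such a signature may be slower than one built from precise types; that trade-off seems acceptable here because the affected statements are unreachable.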

@sunxd3
Collaborator Author

sunxd3 commented Dec 11, 2025

Repost #878 (comment)

A proper fix might be to add a new interpreter that combines both interpreters.

Development

Successfully merging this pull request may close these issues.

DI.hessian: Tuple field type cannot be Union{}

5 participants