Enable second-order differentiation via forward-over-reverse #878
Mooncake.jl documentation for PR #878 is available at:
@yebai @Technici4n this is ready for review. Integration with DI is not complete; I have some guesses about the reasons, but I'm not certain yet. The PR is probably big enough already.
This might be an actual Julia compiler bug. Can you try to create an MWE so we can report it and patch over it?
It's probably okay to keep fixing DI-related issues in this PR if they are not major. I'm also happy to start a separate PR otherwise.
test/ext/differentiation_interface/differentiation_interface.jl
Very cool that you found a way to do this without touching DI internals @sunxd3, congrats and thanks for the effort!!!
It might be, but whatever it is, it's fixed in 1.11. Curious what it is, though. My guess is that type inference failed, producing Union{}. I would prefer to stop here and fix DI integration in another PR.
Thanks @gdalle. Unfortunately, the DI test doesn't work yet. I wanted to keep it simple and use Mooncake internal functions to make a minimal forward-over-reverse work. I might start DI integration in another PR; I don't think we would need an extension, but it definitely needs more investigation.
I agree, second order should work inside Mooncake instead of requiring DI hacks. I just meant that DI has more possible tests to run than what is currently in your integration tests.
Technici4n left a comment:
This is really cool! I will have a deeper look later this week. As always, I am a bit suspicious of new rules that need to be added. 😉
src/interpreter/forward_mode.jl
# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
# than the current world age. Exception: MistyClosure uses its own world age.
We could maybe add a kwarg to disable the check? Special-casing MC here is a bit strange.
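A minimal sketch of the kwarg suggestion, assuming a hypothetical helper check_world_age (not an existing Mooncake function) that receives the interpreter's world age as a plain integer in place of interp.world. Only Base.get_world_counter is a real Julia function here; the point is that a caller with its own world age (such as a MistyClosure) opts out explicitly instead of being special-cased by type.

```julia
# Hypothetical helper illustrating the suggested kwarg; not part of Mooncake.
# `interp_world` stands in for `interp.world`.
function check_world_age(interp_world::Integer; skip_world_check::Bool=false)
    # Callers that manage their own world age opt out explicitly.
    skip_world_check && return nothing
    # Bail out if the current world is newer than the interpreter's world,
    # i.e. methods were defined after the interpreter was created.
    if Base.get_world_counter() > interp_world
        throw(ArgumentError("interpreter world age $interp_world is older than the current world"))
    end
    return nothing
end
```

With this shape, the call site stays generic and the decision to skip the check lives with the caller that knows why it is safe.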
src/rrules/foreigncall.jl
@zero_derivative MinimalCtx Tuple{typeof(Base.has_free_typevars),Any}

# jl_genericmemory_owner is used by Base.dataids to determine memory aliasing.
# The result is just a pointer for aliasing detection, so it has zero derivative.
Why not use @zero_derivative?
My understanding is that @zero_derivative needs the called function as the first element of the signature. _foreigncall_ is a special case because the first argument is _foreigncall_ itself, not the name of the ccall'd function.
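To illustrate the signature point in isolation, here is a self-contained sketch with a hypothetical fake_foreigncall standing in for _foreigncall_, assuming (as the comment implies) that the ccall target is carried as a Val{name} argument. A signature keyed only on the function type matches every ccall target, so singling out one target means constraining that Val argument too.

```julia
# Hypothetical stand-in: every foreign call is routed through this one function,
# with the ccall target encoded as a Val{name} argument.
fake_foreigncall(::Val{name}, args...) where {name} = name

# Keyed on the function type alone: matches any ccall target.
const sig_any_target = Tuple{typeof(fake_foreigncall),Vararg}

# Keyed on the function type and the target name.
const sig_owner = Tuple{typeof(fake_foreigncall),Val{:jl_genericmemory_owner},Vararg}

# Two concrete call signatures with different targets.
call_owner = Tuple{typeof(fake_foreigncall),Val{:jl_genericmemory_owner},Ptr{Nothing}}
call_alloc = Tuple{typeof(fake_foreigncall),Val{:jl_alloc_array_1d},Int}
```

Both calls are subtypes of sig_any_target, but only call_owner is a subtype of sig_owner.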
Interesting: https://github.com/chalk-lab/Mooncake.jl/actions/runs/20058347212/job/57528828789?pr=878 only fails on 1.11 (I temporarily removed the rule to see where it traps).
I did a bit more digging. This turns out to be a pretty interesting hidden issue.
The question to ask here is why the following rule in Mooncake.jl/src/rrules/misc.jl (lines 37 to 41 in b7413d5) does not take effect:
@static if VERSION >= v"1.11-"
    @zero_derivative MinimalCtx Tuple{typeof(Random.hash_seed),Vararg}
    @zero_derivative MinimalCtx Tuple{typeof(Base.dataids),Memory}
end
The answer is that the call is inlined away, and that's why jl_genericmemory_owner is exposed (1.12 doesn't use jl_genericmemory_owner anymore, so this test doesn't error on 1.12).
To verify, with julia +1.11:
using Mooncake
using Mooncake: CC, get_interpreter, BugPatchInterpreter, lookup_ir, optimise_ir!, ForwardMode

f(x::Vector{Float64}) = Base.dataids(x.ref.mem)
interp = get_interpreter(ForwardMode)

# IR from lookup_ir - dataids is NOT inlined (MooncakeInterpreter prevents it)
ir, _ = lookup_ir(interp, Tuple{typeof(f), Vector{Float64}})

# IR after optimise_ir! - dataids IS inlined, exposing jl_genericmemory_owner foreigncall
ir_opt = optimise_ir!(CC.copy(ir); do_inline=true)
gives
julia> ir # dataids is NOT inlined
1 ─ %1 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│ %2 = Base.getfield(%1, :mem)::Memory{Float64}
│ %3 = invoke Base.dataids(%2::Memory{Float64})::Tuple{UInt64} # <-- stays as invoke
└── return %3
julia> ir_opt # dataids IS inlined, exposing foreigncall
1 ─ %1 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│ %2 = Base.getfield(%1, :mem)::Memory{Float64}
│ %3 = foreigncall(:jl_genericmemory_owner, ...)::Any # <-- inlined!
│ %4 = (%3 isa Memory{Float64})::Bool
└── goto #3 if not %4
2 ─ %6 = π (%3, Memory{Float64})
└── goto #4
3 ─ nothing
4 ┄ %9 = φ (#2 => %6, #3 => %2)::Memory{Float64}
│ %10 = Base.getfield(%9, :ptr)::Ptr{Nothing}
│ %11 = Base.bitcast(Ptr{Float64}, %10)::Ptr{Float64}
│ %12 = Core.bitcast(Core.UInt, %11)::UInt64
│ %13 = Core.tuple(%12)::Tuple{UInt64}
└── return %13
Why This Only Affects Forward-over-Reverse
When building rules, lookup_ir uses MooncakeInterpreter, which prevents inlining of primitives. But optimise_ir! then uses BugPatchInterpreter, which has no such protection: it inlines everything, including rrule!!(dataids, ...), down to the raw foreigncall.
In pure reverse mode, rule(...) executes this optimized IR at runtime. The foreigncall simply runs and returns a value.
In forward-over-reverse, build_frule calls lookup_ir(MistyClosure), which returns the already-optimized IR containing the foreigncall. We then try to differentiate through this IR, which requires an frule!! for the foreigncall, but none exists.
src/interpreter/reverse_mode.jl
# Forward-mode primitive for _build_rule! on LazyDerivedRule.
# This avoids differentiating through get_interpreter which has a ccall to jl_get_world_counter.
# The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building.
# Only primitive in ForwardMode - reverse mode uses derived rule.
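The lazy-rule idea behind the comment above can be sketched on its own, assuming a hypothetical LazyRule type (not Mooncake's LazyDerivedRule): the expensive construction runs once, on the first call, and the result is cached. In forward mode, that construction step is the part the comment says should be treated as a primitive, so only the call to the built function carries tangents.

```julia
# Hypothetical sketch of lazy rule construction; not Mooncake's implementation.
mutable struct LazyRule{F}
    build::F      # zero-argument function that constructs the rule
    rule::Any     # cached rule, `nothing` until first use
    nbuilds::Int  # how many times `build` has run (for illustration only)
end
LazyRule(build) = LazyRule(build, nothing, 0)

function (lr::LazyRule)(args...)
    if lr.rule === nothing
        # Built once on first call; this is the step that should not be differentiated.
        lr.rule = lr.build()
        lr.nbuilds += 1
    end
    # Every call, including the first, goes through the cached rule.
    return lr.rule(args...)
end
```

For example, LazyRule(() -> (x -> 2x)) builds its inner function on the first call and reuses it afterwards.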
Is said derived rule working? If not, it's likely better to define a rule that throws rather than letting Mooncake try to differentiate through it and fail.
Added an rrule that throws, stating that reverse-over-reverse is not supported yet.
@sunxd3 That’s possible. For the sake of thoroughness, let’s get to the bottom of this. For example, if your guess above is correct, we ought to see Union{} returned by frule_type for the MWE that’s causing issues.
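For reference, Union{} as an inferred type means inference proved that a call can never return normally (for example, because it always throws). A small self-contained check using Base.return_types (a real Julia reflection utility), with hypothetical functions always_throws and square chosen only for illustration:

```julia
# A function that can never return normally.
always_throws(x::Int) = throw(ArgumentError("bad input: $x"))

# A function that returns normally.
square(x::Int) = x * x

# Inferred return types for the given argument types (a Vector{Any}).
rt_throw = Base.return_types(always_throws, (Int,))
rt_square = Base.return_types(square, (Int,))
```

A rule-type query that reports Union{} for a call that does return at runtime would therefore point at an inference failure rather than at the function itself.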
Signed-off-by: Hong Ge <[email protected]>
On the … With the following code example:
using Mooncake: build_rrule, lookup_ir, get_interpreter, ReverseMode

f(x) = sum(x .* x)
x = [1.0, 2.0]

# Build the reverse-mode rule, then look up the IR of its forward-pass MistyClosure.
rule = build_rrule(f, x)
interp = get_interpreter(ReverseMode)
mooncake_ir, _ = lookup_ir(interp, rule.fwds_oc)

# Print every statement whose inferred type is Union{} (never returns normally).
for i in 1:length(mooncake_ir.stmts)
    mooncake_ir.stmts[i][:type] === Union{} || continue
    # The instruction field is :stmt on Julia 1.11+ and :inst on earlier versions.
    stmt = @static if VERSION >= v"1.11-"
        mooncake_ir.stmts[i][:stmt]
    else
        mooncake_ir.stmts[i][:inst]
    end
    typ = mooncake_ir.stmts[i][:type]
    println("SSA $i :: $typ = $(typeof(stmt).name.name) - $stmt")
end
On Julia 1.10, this returns:
SSA 124 :: Union{} = GotoIfNot - goto %44 if not %123
SSA 127 :: Union{} = GotoNode - goto %47
SSA 129 :: Union{} = GotoIfNot - goto %46 if not %128
SSA 132 :: Union{} = GotoNode - goto %47
SSA 134 :: Union{} = GotoNode - goto %47
SSA 144 :: Union{} = Expr - (Mooncake.rrule!!)(...)
SSA 477 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 504 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 544 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 600 :: Union{} = Expr - invoke MistyClosure(throwdm)(...)
SSA 602 :: Union{} = PhiNode - φ () # ← ORPHANED!
SSA 604 :: Union{} = Expr - invoke Mooncake._build_rule!(...)
SSA 606 :: Union{} = PhiNode - φ () # ← ORPHANED!
SSA 607 :: Union{} = Expr - (getfield)(%606, 1) # ← USES Union{} phi!
SSA 608 :: Union{} = Expr - (getfield)(%606, 2) # ← USES Union{} phi!
SSA 609 :: Union{} = Expr - (tuple)(%608)
SSA 610 :: Union{} = Expr - (push!)(%9, %609)
On Julia 1.12, this returns:
SSA 45 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 46 :: Union{} = ReturnNode - unreachable # ← CLEAN!
SSA 79 :: Union{} = Expr - invoke Base.throw_boundserror(...)
SSA 80 :: Union{} = ReturnNode - unreachable # ← CLEAN!
# ... (pattern continues: throw call → unreachable, no orphaned phi nodes)
SSA 969 :: Union{} = Expr - invoke MistyClosure(throwdm)(...)
SSA 970 :: Union{} = ReturnNode - unreachable # ← CLEAN!
SSA 971 :: Union{} = Expr - invoke Mooncake._build_rule!(...)
SSA 972 :: Union{} = ReturnNode - unreachable # ← CLEAN!
# ... (61 total Union{} statements, all are Expr or ReturnNode)
The Culprit
The orphaned … When Mooncake processes …:
sig_types = [type_of(%606), type_of(1)]  # = [Union{}, Int]
sig = Tuple{Union{}, Int}  # ← ERROR on Julia 1.10+!
I am still trying to nail down which Julia compiler PRs fix these. The fix should be applied to …
Why Other …
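A minimal sketch of the workaround the PR description mentions (replacing Union{} with Any before building a signature), assuming hypothetical helper names widen_bottom and safe_signature rather than Mooncake's actual code. Since constructing Tuple{Union{}, Int} fails on Julia 1.10+, each element type is widened first.

```julia
# Hypothetical helper: replace the bottom type with Any, leave others unchanged.
widen_bottom(T) = T === Union{} ? Any : T

# Build a Tuple signature from inferred argument types without ever
# constructing Tuple{..., Union{}, ...}, which Julia 1.10+ rejects.
safe_signature(sig_types) = Tuple{map(widen_bottom, sig_types)...}
```

Any is a conservative choice: a Union{}-typed value can never exist at runtime, so widening it only affects type-level bookkeeping, at the cost of a less precise signature.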
Repost of #878 (comment): a proper fix might be to add a new interpreter that combines both interpreters.
Ref: #826
This PR aims to enable forward-over-reverse for second-order derivatives. Much of the necessary plumbing was already implemented before this PR; this PR smooths out the remaining rough edges.
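To make the idea concrete, here is a self-contained toy sketch of forward-over-reverse, independent of Mooncake: a hand-written pullback of f(x) = sum(x .* x) (the same function as in the MWE in this thread), standing in for a reverse-mode rule, is evaluated on a minimal dual-number type, so the tangent part of the resulting gradient is a Hessian-vector product. Dual, grad_f and hvp are illustrative assumptions, not Mooncake APIs.

```julia
# Minimal forward-mode dual number: primal value plus one tangent.
struct Dual
    val::Float64
    tan::Float64
end
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.tan * b.val + a.val * b.tan)
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.tan)
Base.one(::Type{Dual}) = Dual(1.0, 0.0)

# Hand-written reverse pass for f(x) = sum(x .* x):
# the pullback maps the output cotangent ȳ to x̄ = 2 .* x .* ȳ.
function grad_f(x::AbstractVector)
    ȳ = one(eltype(x))
    return [2 * xi * ȳ for xi in x]
end

# Forward-over-reverse: seed each input with tangent v, run the reverse pass
# on duals, and read off the tangents, which give H * v.
hvp(x, v) = [d.tan for d in grad_f(Dual.(x, v))]
```

Here the Hessian is 2I, so hvp(x, v) returns 2v, while the primal parts of grad_f(Dual.(x, v)) reproduce the ordinary gradient 2x.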
Fixes:
- … zero_tangent(pb_mc) was called.
- Union{} type issue (this is somewhat 1.10-specific): Union{} can show up when type inference fails; the approach here is to replace it with Any.
- Core.memorynew
- LazyDerivedRule: avoid differentiating through get_interpreter (jl_get_world_counter)
- jl_genericmemory_owner for 1.11
- jl_alloc_array_1d/2d/3d
What doesn't work yet (hangs):
- DI.hessian with SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))
Closes #632