Skip to content

Commit 4071f32

Browse files
authored
Add Type Checking in Forward Mode (#808)
* code and test * Revert "code and test" This reverts commit a9af7cb. * code and test * formatting * try fixing test error * remove confusing example * address Bruno's suggestions * typo * separate the two debug mode tests * remove forward debug mode for temp test * fix 1.10 segfault * use generated function to avoid ever trap into the compiler error * formatting * more formatting * add couple of test to mirror reverse mode * bump version
1 parent 36cd560 commit 4071f32

File tree

5 files changed

+285
-44
lines changed

5 files changed

+285
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.180"
4+
version = "0.4.181"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/debug_mode.jl

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,135 @@
1-
# TODO: make it non-trivial. See https://github.com/chalk-lab/Mooncake.jl/issues/672
2-
DebugFRule(rule) = rule
1+
"""
2+
DebugFRule(rule)
3+
4+
Construct a callable equivalent to `rule` but with additional type checking for forward-mode
5+
AD. Checks:
6+
- Each `Dual` argument has tangent type matching `tangent_type(typeof(primal))`
7+
- The returned `Dual` has a correctly-typed tangent
8+
- Deep structural validation (array sizes, field types, etc.)
9+
10+
Forward-mode counterpart to [`DebugRRule`](@ref).
11+
12+
*Note:* Debug mode significantly slows execution (10-100x) and should only be used for
13+
diagnosing problems, not production runs.
14+
```
15+
"""
16+
struct DebugFRule{Trule}
17+
rule::Trule
18+
end
19+
20+
# Recursively copy the wrapped rule
21+
_copy(x::P) where {P<:DebugFRule} = P(_copy(x.rule))
22+
23+
"""
24+
(rule::DebugFRule)(x::Vararg{Dual,N}) where {N}
25+
26+
Apply pre- and post-condition type checking. See [`DebugFRule`](@ref).
27+
"""
28+
@static if VERSION < v"1.11-"
29+
# On Julia 1.10, use @generated to check types at compile time, preventing the
30+
# compiler from ever seeing rule.rule(x...) with mismatched types, which would
31+
# cause a segfault (JuliaLang/julia#51016).
32+
@generated function (rule::DebugFRule{Trule})(x::Vararg{Dual,N}) where {Trule,N}
33+
# First, check tangent type consistency for all Dual inputs at compile time.
34+
# This prevents the compiler from generating code for rule.rule(x...) with
35+
# mismatched Dual types (e.g., Dual{Float64,Float32} instead of Dual{Float64,Float64}).
36+
for dt in x
37+
P = dt.parameters[1] # primal type
38+
T = dt.parameters[2] # tangent type
39+
T_expected = tangent_type(P)
40+
if T !== T_expected
41+
msg = "Error in inputs to rule with input types $(Tuple{x...})"
42+
return :(error($msg))
43+
end
44+
end
45+
46+
# Check primal types match rule signature
47+
if Trule <: DerivedFRule && isconcretetype(Trule)
48+
sig = Trule.parameters[1] # primal_sig
49+
isva = Trule.parameters[3]
50+
nargs_val = Trule.parameters[4]
51+
52+
# Extract primal types
53+
primal_types = [dt.parameters[1] for dt in x]
54+
55+
# Handle varargs unflattening
56+
if isva
57+
regular_types = primal_types[1:(nargs_val - 1)]
58+
vararg_types = primal_types[nargs_val:end]
59+
grouped_type = Tuple{vararg_types...}
60+
final_types = [regular_types..., grouped_type]
61+
else
62+
final_types = primal_types
63+
end
64+
65+
Tx = Tuple{final_types...}
66+
if !(Tx <: sig)
67+
msg = "Error in inputs to rule with input types $(Tuple{x...})"
68+
return :(error($msg))
69+
end
70+
end
71+
72+
return quote
73+
verify_dual_inputs(x)
74+
y = rule.rule(x...)
75+
verify_dual_output(x, y)
76+
return y
77+
end
78+
end
79+
else
80+
@noinline function (rule::DebugFRule)(x::Vararg{Dual,N}) where {N}
81+
try
82+
verify_args(rule.rule, x)
83+
catch
84+
error("Error in inputs to rule with input types $(_typeof(x))")
85+
end
86+
verify_dual_inputs(x)
87+
y = rule.rule(x...)
88+
verify_dual_output(x, y)
89+
return y::Dual
90+
end
91+
end
92+
93+
@noinline function verify_dual_inputs(@nospecialize(x::Tuple))
94+
try
95+
for _x in x
96+
_x isa Dual || error("Expected Dual, got $(typeof(_x))")
97+
verify_dual_value(_x)
98+
end
99+
catch e
100+
error("Error in inputs to rule with input types $(_typeof(x))")
101+
end
102+
end
103+
104+
@noinline function verify_dual_output(@nospecialize(x), @nospecialize(y))
105+
try
106+
y isa Dual || error("frule!! must return a Dual, got $(typeof(y))")
107+
verify_dual_value(y)
108+
catch e
109+
error("Error in outputs of rule with input types $(_typeof(x))")
110+
end
111+
end
112+
113+
@noinline function verify_dual_value(d::Dual{P,T}) where {P,T}
114+
# Fast path: type-level check using the Dual type parameters to enforce T == tangent_type(P)
115+
T_expected = tangent_type(P)
116+
if T !== T_expected
117+
throw(
118+
InvalidFDataException(
119+
"Dual tangent type mismatch: primal $P requires tangent type " *
120+
"$T_expected, but got $T",
121+
),
122+
)
123+
end
124+
125+
# Slow path: deep structural validation
126+
p, t = primal(d), tangent(d)
127+
# We validate fdata and rdata separately so these helpers stay in sync with reverse-mode checks.
128+
verify_fdata_value(p, fdata(t))
129+
verify_rdata_value(p, rdata(t))
130+
131+
return nothing
132+
end
3133

4134
"""
5135
DebugPullback(pb, y, x)

src/interpreter/forward_mode.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,18 @@ function _copy(x::P) where {P<:DerivedFRule}
8989
return P(replace_captures(x.fwd_oc, _copy(x.fwd_oc.oc.captures)))
9090
end
9191

92+
_isva(::DerivedFRule{P,T,isva,nargs}) where {P,T,isva,nargs} = isva
93+
_nargs(::DerivedFRule{P,T,isva,nargs}) where {P,T,isva,nargs} = nargs
94+
95+
# Extends functionality defined in debug_mode.jl.
96+
function verify_args(r::DerivedFRule{sig}, x) where {sig}
97+
Tx = Tuple{
98+
map(_typeof primal, __unflatten_dual_varargs(_isva(r), x, Val(_nargs(r))))...
99+
}
100+
Tx <: sig && return nothing
101+
throw(ArgumentError("Arguments with sig $Tx do not subtype rule signature, $sig"))
102+
end
103+
92104
"""
93105
__unflatten_dual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}
94106

src/test_utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ using Mooncake:
9696
MutableTangent,
9797
frule!!,
9898
rrule!!,
99+
DebugFRule,
99100
build_rrule,
100101
tangent_type,
101102
zero_tangent,
@@ -944,8 +945,8 @@ function test_rule(
944945
frule = test_fwd ? build_frule(fwd_interp, sig; debug_mode) : missing
945946
rrule = test_rvs ? build_rrule(rvs_interp, sig; debug_mode) : missing
946947

947-
# If something is primitive, then the rule should be `rrule!!`.
948-
test_fwd && is_primitive && @test frule == frule!!
948+
# If something is primitive, then the rule should be `frule!!` or `rrule!!`.
949+
test_fwd && is_primitive && @test frule == (debug_mode ? DebugFRule(frule!!) : frule!!)
949950
test_rvs && is_primitive && @test rrule == (debug_mode ? DebugRRule(rrule!!) : rrule!!)
950951

951952
# Generate random tangents for anything that is not already a CoDual.

test/debug_mode.jl

Lines changed: 137 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,142 @@
11
@testset "debug_mode" begin
2+
@testset "reverse debug mode" begin
3+
# Unless we explicitly check that the arguments are of the type as expected by the rule,
4+
# this will segfault.
5+
@testset "argument checking" begin
6+
f = x -> 5x
7+
rule = build_rrule(f, 5.0; debug_mode=true)
8+
@test_throws ErrorException rule(zero_fcodual(f), CoDual(0.0f0, 1.0f0))
9+
end
210

3-
# Unless we explicitly check that the arguments are of the type as expected by the rule,
4-
# this will segfault.
5-
@testset "argument checking" begin
6-
f = x -> 5x
7-
rule = build_rrule(f, 5.0; debug_mode=true)
8-
@test_throws ErrorException rule(zero_fcodual(f), CoDual(0.0f0, 1.0f0))
11+
# Forwards-pass tests.
12+
x = (CoDual(sin, NoTangent()), CoDual(5.0, NoFData()))
13+
@test_throws(ErrorException, Mooncake.DebugRRule(rrule!!)(x...))
14+
x = (CoDual(sin, NoFData()), CoDual(5.0, NoFData()))
15+
@test_throws(
16+
ErrorException,
17+
Mooncake.DebugRRule((x...,) -> (CoDual(1.0, 0.0), nothing))(x...)
18+
)
19+
20+
# Basic type checking.
21+
x = (CoDual(size, NoFData()), CoDual(randn(10), randn(Float16, 11)))
22+
@test_throws ErrorException Mooncake.DebugRRule(rrule!!)(x...)
23+
24+
# Element type checking. Abstractly typed-elements prevent determining incorrectness
25+
# just by looking at the array.
26+
x = (
27+
CoDual(size, NoFData()),
28+
CoDual(Any[rand() for _ in 1:10], Any[rand(Float16) for _ in 1:10]),
29+
)
30+
@test_throws ErrorException Mooncake.DebugRRule(rrule!!)(x...)
31+
32+
# Test that bad rdata is caught as a pre-condition.
33+
y, pb!! = Mooncake.DebugRRule(rrule!!)(zero_fcodual(sin), zero_fcodual(5.0))
34+
@test_throws(InvalidRDataException, pb!!(5))
35+
36+
# Test that bad rdata is caught as a post-condition.
37+
rule_with_bad_pb(x::CoDual{Float64}) = x, dy -> (5,) # returns the wrong type
38+
y, pb!! = Mooncake.DebugRRule(rule_with_bad_pb)(zero_fcodual(5.0))
39+
@test_throws InvalidRDataException pb!!(1.0)
40+
41+
# Test that bad rdata is caught as a post-condition.
42+
rule_with_bad_pb_length(x::CoDual{Float64}) = x, dy -> (5, 5.0) # returns the wrong type
43+
y, pb!! = Mooncake.DebugRRule(rule_with_bad_pb_length)(zero_fcodual(5.0))
44+
@test_throws ErrorException pb!!(1.0)
945
end
1046

11-
# Forwards-pass tests.
12-
x = (CoDual(sin, NoTangent()), CoDual(5.0, NoFData()))
13-
@test_throws(ErrorException, Mooncake.DebugRRule(rrule!!)(x...))
14-
x = (CoDual(sin, NoFData()), CoDual(5.0, NoFData()))
15-
@test_throws(
16-
ErrorException, Mooncake.DebugRRule((x...,) -> (CoDual(1.0, 0.0), nothing))(x...)
17-
)
18-
19-
# Basic type checking.
20-
x = (CoDual(size, NoFData()), CoDual(randn(10), randn(Float16, 11)))
21-
@test_throws ErrorException Mooncake.DebugRRule(rrule!!)(x...)
22-
23-
# Element type checking. Abstractly typed-elements prevent determining incorrectness
24-
# just by looking at the array.
25-
x = (
26-
CoDual(size, NoFData()),
27-
CoDual(Any[rand() for _ in 1:10], Any[rand(Float16) for _ in 1:10]),
28-
)
29-
@test_throws ErrorException Mooncake.DebugRRule(rrule!!)(x...)
30-
31-
# Test that bad rdata is caught as a pre-condition.
32-
y, pb!! = Mooncake.DebugRRule(rrule!!)(zero_fcodual(sin), zero_fcodual(5.0))
33-
@test_throws(InvalidRDataException, pb!!(5))
34-
35-
# Test that bad rdata is caught as a post-condition.
36-
rule_with_bad_pb(x::CoDual{Float64}) = x, dy -> (5,) # returns the wrong type
37-
y, pb!! = Mooncake.DebugRRule(rule_with_bad_pb)(zero_fcodual(5.0))
38-
@test_throws InvalidRDataException pb!!(1.0)
39-
40-
# Test that bad rdata is caught as a post-condition.
41-
rule_with_bad_pb_length(x::CoDual{Float64}) = x, dy -> (5, 5.0) # returns the wrong type
42-
y, pb!! = Mooncake.DebugRRule(rule_with_bad_pb_length)(zero_fcodual(5.0))
43-
@test_throws ErrorException pb!!(1.0)
47+
@testset "forward debug mode" begin
48+
@testset "argument checking" begin
49+
f = x -> 5x
50+
rule = Mooncake.build_frule(zero_dual(f), 5.0; debug_mode=true)
51+
@test_throws ErrorException rule(
52+
zero_dual(f), Mooncake.Dual(Float32(5.0), Float32(1.0))
53+
)
54+
end
55+
56+
@testset "valid inputs pass" begin
57+
# Single argument - use Float64, not π which has NoTangent
58+
rule = Mooncake.build_frule(zero_dual(sin), 0.0; debug_mode=true)
59+
@test rule(zero_dual(sin), Mooncake.Dual(3.14, 1.0)) isa Mooncake.Dual
60+
61+
# Multiple arguments
62+
f_mul(x, y) = x * y
63+
rule = Mooncake.build_frule(zero_dual(f_mul), 2.0, 3.0; debug_mode=true)
64+
@test rule(
65+
zero_dual(f_mul), Mooncake.Dual(2.0, 1.0), Mooncake.Dual(3.0, 0.5)
66+
) isa Mooncake.Dual
67+
68+
# Arrays
69+
h(x) = sum(x)
70+
rule = Mooncake.build_frule(zero_dual(h), randn(5); debug_mode=true)
71+
@test rule(zero_dual(h), Mooncake.Dual(randn(5), randn(5))) isa Mooncake.Dual
72+
73+
# NoTangent (non-differentiable)
74+
rule = Mooncake.build_frule(zero_dual(identity), 5; debug_mode=true)
75+
@test rule(zero_dual(identity), Mooncake.Dual(5, NoTangent())) isa Mooncake.Dual
76+
end
77+
78+
@testset "size mismatch detected" begin
79+
rule = Mooncake.build_frule(zero_dual(size), randn(10); debug_mode=true)
80+
@test_throws ErrorException rule(
81+
zero_dual(size), Mooncake.Dual(randn(11), randn(10))
82+
)
83+
end
84+
85+
@testset "element type mismatch detected" begin
86+
rule = Mooncake.build_frule(zero_dual(identity), Any[1.0]; debug_mode=true)
87+
@test_throws ErrorException rule(
88+
zero_dual(identity), Mooncake.Dual(Any[1.0], Any[Float16(1.0)])
89+
)
90+
end
91+
92+
@testset "scalar type mismatch detected" begin
93+
rule = Mooncake.build_frule(zero_dual(identity), 1.0; debug_mode=true)
94+
@test_throws ErrorException rule(
95+
zero_dual(identity), Mooncake.Dual(1.0, Float32(1.0))
96+
)
97+
end
98+
99+
@testset "container type mismatch detected" begin
100+
rule = Mooncake.build_frule(zero_dual(identity), (1.0, 2.0); debug_mode=true)
101+
@test_throws ErrorException rule(
102+
zero_dual(identity), Mooncake.Dual((1.0, 2.0), [1.0, 2.0])
103+
)
104+
end
105+
106+
@testset "output tangent type mismatch detected" begin
107+
# Rule that returns wrong tangent type in output
108+
bad_rule = Mooncake.DebugFRule((x...,) -> Mooncake.Dual(1.0, Float32(0.0)))
109+
@test_throws ErrorException bad_rule(Mooncake.Dual(5.0, 1.0))
110+
end
111+
112+
@testset "error messages include type info" begin
113+
rule = Mooncake.build_frule(zero_dual(identity), [1.0]; debug_mode=true)
114+
115+
try
116+
rule(zero_dual(identity), Mooncake.Dual([1.0], [Float32(1.0)]))
117+
@test false # Expected ErrorException but none was thrown
118+
catch e
119+
msg = sprint(showerror, e)
120+
@test occursin("input types", msg)
121+
@test occursin("Float", msg) # Type info present
122+
end
123+
end
124+
125+
@testset "integration with test_rule" begin
126+
# Test basic case - test_rule expects primal functions, not Duals
127+
Mooncake.TestUtils.test_rule(
128+
sr(123456), sin, 1.0; mode=ForwardMode, debug_mode=true, perf_flag=:none
129+
)
130+
131+
# Test with array
132+
Mooncake.TestUtils.test_rule(
133+
sr(123456),
134+
sum,
135+
randn(5);
136+
mode=ForwardMode,
137+
debug_mode=true,
138+
perf_flag=:none,
139+
)
140+
end
141+
end
44142
end

0 commit comments

Comments
 (0)