Skip to content

Commit 2dbcdc9

Browse files
authored
Make tangent of CuArray have the same type (#811)
* Make tangent of CuArray have the same type * Add fdata and rdata * Incremental * Fix rrule * Format * Comments * Two underscores
1 parent de55766 commit 2dbcdc9

File tree

2 files changed

+140
-23
lines changed

2 files changed

+140
-23
lines changed

ext/MooncakeCUDAExt.jl

Lines changed: 123 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ module MooncakeCUDAExt
33
using LinearAlgebra, Random, Mooncake
44

55
using Base: IEEEFloat
6-
using CUDA: CuArray, cu
6+
using CUDA: CuArray
77

88
import Mooncake:
99
MinimalCtx,
1010
rrule!!,
1111
@is_primitive,
1212
tangent_type,
13+
fdata_type,
14+
rdata_type,
1315
primal,
1416
tangent,
1517
zero_tangent_internal,
@@ -33,35 +35,95 @@ import Mooncake.TestUtils:
3335
populate_address_map_internal, AddressMap, __increment_should_allocate
3436

3537
const CuFloatArray = CuArray{<:IEEEFloat}
38+
const CuComplexArray = CuArray{<:Complex{<:IEEEFloat}}
3639

3740
# Tell Mooncake.jl how to handle CuArrays.
3841

39-
Mooncake.@foldable tangent_type(::Type{P}) where {P<:CuFloatArray} = P
42+
Mooncake.@foldable tangent_type(::Type{<:CuArray{P,N,M}}) where {P<:Union{Complex{<:IEEEFloat},IEEEFloat},N,M} = CuArray{
43+
tangent_type(P),N,M
44+
}
45+
46+
Mooncake.@foldable fdata_type(::Type{CuArray{P,N,M}}) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M} = CuArray{
47+
P,N,M
48+
}
49+
50+
Mooncake.@foldable rdata_type(
51+
::Type{<:CuArray{P,N,M}}
52+
) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M} = Mooncake.NoRData
53+
4054
function zero_tangent_internal(x::CuFloatArray, dict::MaybeCache)
4155
haskey(dict, x) && return dict[x]::tangent_type(typeof(x))
4256
t = zero(x)
4357
dict[x] = t
4458
return t
4559
end
46-
function randn_tangent_internal(rng::AbstractRNG, x::CuFloatArray, dict::MaybeCache)
60+
function zero_tangent_internal(x::CuArray{T}, dict::MaybeCache) where {T<:Complex}
61+
haskey(dict, x) && return dict[x]::tangent_type(typeof(x))
62+
t = tangent_type(typeof(x))(undef, size(x))
63+
t_ = reinterpret(T, t)
64+
t_ .= zero(T)
65+
dict[x] = t
66+
return t
67+
end
68+
function randn_tangent_internal(
69+
rng::AbstractRNG, x::CuArray{T}, dict::MaybeCache
70+
) where {T<:IEEEFloat}
71+
haskey(dict, x) && return dict[x]::tangent_type(typeof(x))
72+
t = CuArray(randn(rng, T, size(x)...))
73+
dict[x] = t
74+
return t
75+
end
76+
function randn_tangent_internal(
77+
rng::AbstractRNG, x::CuArray{T}, dict::MaybeCache
78+
) where {T<:Complex}
4779
haskey(dict, x) && return dict[x]::tangent_type(typeof(x))
48-
t = cu(randn(rng, Float32, size(x)...))
80+
t = tangent_type(typeof(x))(undef, size(x))
81+
t_ = reinterpret(T, t)
82+
th = randn(rng, T, size(x)...)
83+
t_ .= CuArray(th)
4984
dict[x] = t
5085
return t
5186
end
5287
function TestUtils.has_equal_data_internal(
5388
x::P, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}
54-
) where {P<:CuFloatArray}
89+
) where {P<:Union{CuFloatArray,CuComplexArray}}
5590
return isapprox(x, y)
5691
end
57-
function increment_internal!!(c::IncCache, x::P, y::P) where {P<:CuFloatArray}
92+
function TestUtils.has_equal_data_internal(
93+
x::CuArray{P,N,M}, y::CuArray{P,N,M}, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}
94+
) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M}
95+
x_ = reinterpret(Complex{T}, x)
96+
y_ = reinterpret(Complex{T}, y)
97+
return isapprox(x_, y_)
98+
end
99+
function increment_internal!!(
100+
c::IncCache, x::CuArray{P,N,M}, y::CuArray{P,N,M}
101+
) where {P<:IEEEFloat,N,M}
58102
(x === y || haskey(c, x)) && return x
59103
c[x] = true
60104
x .+= y
61105
return x
62106
end
107+
function increment_internal!!(
108+
c::IncCache, x::CuArray{P,N,M}, y::CuArray{P,N,M}
109+
) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M}
110+
(x === y || haskey(c, x)) && return x
111+
c[x] = true
112+
x_ = reinterpret(Complex{T}, x)
113+
y_ = reinterpret(Complex{T}, y)
114+
x_ .+= y_
115+
return x
116+
end
63117
__increment_should_allocate(::Type{<:CuFloatArray}) = true
64118
set_to_zero_internal!!(::Mooncake.SetToZeroCache, x::CuFloatArray) = x .= 0
119+
function set_to_zero_internal!!(
120+
::Mooncake.SetToZeroCache, x::CuArray{Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M}
121+
) where {T<:IEEEFloat,N,M}
122+
x_ = reinterpret(Complex{T}, x)
123+
x_ .= zero(Complex{T})
124+
return x
125+
end
126+
65127
function _add_to_primal_internal(
66128
c::MaybeCache, x::P, y::P, unsafe::Bool
67129
) where {P<:CuFloatArray}
@@ -71,24 +133,61 @@ function _add_to_primal_internal(
71133
c[(x, y, unsafe)] = x′
72134
return x′
73135
end
136+
function _add_to_primal_internal(
137+
c::MaybeCache, x::P, y::TP, unsafe::Bool
138+
) where {P<:CuComplexArray,TP}
139+
key = (x, y, unsafe)
140+
haskey(c, key) && return c[key]::P
141+
x′ = x + reinterpret(eltype(x), y)
142+
c[(x, y, unsafe)] = x′
143+
return x′
144+
end
74145
function _diff_internal(c::MaybeCache, x::P, y::P) where {P<:CuFloatArray}
75146
key = (x, y)
76147
haskey(c, key) && return c[key]::tangent_type(P)
77148
t = x - y
78149
c[key] = t
79150
return t
80151
end
152+
function _diff_internal(c::MaybeCache, x::P, y::P) where {P<:CuComplexArray}
153+
key = (x, y)
154+
haskey(c, key) && return c[key]::tangent_type(P)
155+
t = tangent_type(P)(undef, size(x))
156+
t_ = reinterpret(eltype(x), t)
157+
@. t_ = x - y
158+
c[key] = t
159+
return t
160+
end
81161
function _dot_internal(c::MaybeCache, x::P, y::P) where {P<:CuFloatArray}
82162
key = (x, y)
83163
haskey(c, key) && return c[key]::Float64
84164
return Float64(dot(x, y))
85165
end
166+
function _dot_internal(
167+
c::MaybeCache, x::CuArray{P}, y::CuArray{P}
168+
) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}}}
169+
key = (x, y)
170+
haskey(c, key) && return c[key]::Float64
171+
x_ = reinterpret(Complex{T}, x)
172+
y_ = reinterpret(Complex{T}, y)
173+
return Float64(real(dot(x_, y_)))
174+
end
86175
function _scale_internal(c::MaybeCache, x::Float64, y::P) where {T<:IEEEFloat,P<:CuArray{T}}
87176
haskey(c, y) && return c[y]::P
88177
t′ = T(x) * y
89178
c[y] = t′
90179
return t′
91180
end
181+
function _scale_internal(
182+
c::MaybeCache, x::Float64, y::CuArray{P,N,M}
183+
) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M}
184+
haskey(c, y) && return c[y]::CuArray{P,N,M}
185+
t′ = copy(y)
186+
t′_ = reinterpret(Complex{T}, t′)
187+
t′_ .*= T(x)
188+
c[y] = t′
189+
return t′
190+
end
92191
function populate_address_map_internal(m::AddressMap, p::CuArray, t::CuArray)
93192
k = pointer_from_objref(p)
94193
v = pointer_from_objref(t)
@@ -102,8 +201,16 @@ function Mooncake.__verify_fdata_value(::IdDict{Any,Nothing}, p::CuArray, f::CuA
102201
end
103202
return nothing
104203
end
105-
Mooncake.@foldable tangent_type(::Type{P}, ::Type{NoRData}) where {P<:CuArray} = P
106-
tangent(p::CuArray, ::NoRData) = p
204+
Mooncake.@foldable tangent_type(::Type{P}, ::Type{NoRData}) where {P<:CuFloatArray} = P
205+
Mooncake.@foldable tangent_type(::Type{CuArray{P,N,M}}, ::Type{NoRData}) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M} = CuArray{
206+
P,N,M
207+
}
208+
tangent(p::CuFloatArray, ::NoRData) = p
209+
function tangent(
210+
p::CuArray{P,N,M}, ::NoRData
211+
) where {T<:IEEEFloat,P<:Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M}
212+
p
213+
end
107214

108215
to_cr_tangent(x::CuFloatArray) = x
109216
function increment_and_get_rdata!(f::T, ::NoRData, t::T) where {T<:CuFloatArray}
@@ -120,5 +227,13 @@ function rrule!!(
120227
_dims = map(primal, dims)
121228
return CoDual(P(undef, _dims), P(undef, _dims)), NoPullback(p, init, dims...)
122229
end
230+
function rrule!!(
231+
p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}...
232+
) where {P<:CuComplexArray}
233+
_dims = map(primal, dims)
234+
return (
235+
CoDual(P(undef, _dims), tangent_type(P)(undef, _dims)), NoPullback(p, init, dims...)
236+
)
237+
end
123238

124239
end

test/ext/cuda/cuda.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,24 @@ using Mooncake.TestUtils: test_tangent_interface, test_tangent_splitting, test_r
77

88
@testset "cuda" begin
99
if CUDA.functional()
10-
# Check we can operate on CuArrays.
11-
p = CuArray{Float32,2,CUDA.DeviceMemory}(undef, 8, 8)
12-
test_tangent_interface(StableRNG(123456), p; interface_only=false)
13-
test_tangent_splitting(StableRNG(123456), p)
10+
# Check we can operate on CuArrays of various element types.
11+
@testset for ET in (Float32, Float64, ComplexF32, ComplexF64)
12+
p = CuArray{ET,2,CUDA.DeviceMemory}(undef, 8, 8)
13+
test_tangent_interface(StableRNG(123456), p; interface_only=false)
14+
test_tangent_splitting(StableRNG(123456), p)
1415

15-
# Check we can instantiate a CuArray.
16-
test_rule(
17-
StableRNG(123456),
18-
CuArray{Float32,1,CUDA.DeviceMemory},
19-
undef,
20-
256;
21-
interface_only=true,
22-
is_primitive=true,
23-
debug_mode=true,
24-
mode=Mooncake.ReverseMode,
25-
)
16+
# Check we can instantiate a CuArray.
17+
test_rule(
18+
StableRNG(123456),
19+
CuArray{ET,1,CUDA.DeviceMemory},
20+
undef,
21+
256;
22+
interface_only=true,
23+
is_primitive=true,
24+
debug_mode=true,
25+
mode=Mooncake.ReverseMode,
26+
)
27+
end
2628
else
2729
println("Tests are skipped since no CUDA device was found. ")
2830
end

0 commit comments

Comments
 (0)