@@ -3,13 +3,15 @@ module MooncakeCUDAExt
33using LinearAlgebra, Random, Mooncake
44
55using Base: IEEEFloat
6- using CUDA: CuArray, cu
6+ using CUDA: CuArray
77
88import Mooncake:
99 MinimalCtx,
1010 rrule!!,
1111 @is_primitive ,
1212 tangent_type,
13+ fdata_type,
14+ rdata_type,
1315 primal,
1416 tangent,
1517 zero_tangent_internal,
@@ -33,35 +35,95 @@ import Mooncake.TestUtils:
3335 populate_address_map_internal, AddressMap, __increment_should_allocate
3436
3537const CuFloatArray = CuArray{<: IEEEFloat }
38+ const CuComplexArray = CuArray{<: Complex{<:IEEEFloat} }
3639
3740# Tell Mooncake.jl how to handle CuArrays.
3841
39- Mooncake. @foldable tangent_type (:: Type{P} ) where {P<: CuFloatArray } = P
42+ Mooncake. @foldable tangent_type (:: Type{<:CuArray{P,N,M}} ) where {P<: Union{Complex{<:IEEEFloat},IEEEFloat} ,N,M} = CuArray{
43+ tangent_type (P),N,M
44+ }
45+
46+ Mooncake. @foldable fdata_type (:: Type{CuArray{P,N,M}} ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M} = CuArray{
47+ P,N,M
48+ }
49+
50+ Mooncake. @foldable rdata_type (
51+ :: Type{<:CuArray{P,N,M}}
52+ ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M} = Mooncake. NoRData
53+
4054function zero_tangent_internal (x:: CuFloatArray , dict:: MaybeCache )
4155 haskey (dict, x) && return dict[x]:: tangent_type (typeof (x))
4256 t = zero (x)
4357 dict[x] = t
4458 return t
4559end
46- function randn_tangent_internal (rng:: AbstractRNG , x:: CuFloatArray , dict:: MaybeCache )
60+ function zero_tangent_internal (x:: CuArray{T} , dict:: MaybeCache ) where {T<: Complex }
61+ haskey (dict, x) && return dict[x]:: tangent_type (typeof (x))
62+ t = tangent_type (typeof (x))(undef, size (x))
63+ t_ = reinterpret (T, t)
64+ t_ .= zero (T)
65+ dict[x] = t
66+ return t
67+ end
68+ function randn_tangent_internal (
69+ rng:: AbstractRNG , x:: CuArray{T} , dict:: MaybeCache
70+ ) where {T<: IEEEFloat }
71+ haskey (dict, x) && return dict[x]:: tangent_type (typeof (x))
72+ t = CuArray (randn (rng, T, size (x)... ))
73+ dict[x] = t
74+ return t
75+ end
76+ function randn_tangent_internal (
77+ rng:: AbstractRNG , x:: CuArray{T} , dict:: MaybeCache
78+ ) where {T<: Complex }
4779 haskey (dict, x) && return dict[x]:: tangent_type (typeof (x))
48- t = cu (randn (rng, Float32, size (x)... ))
80+ t = tangent_type (typeof (x))(undef, size (x))
81+ t_ = reinterpret (T, t)
82+ th = randn (rng, T, size (x)... )
83+ t_ .= CuArray (th)
4984 dict[x] = t
5085 return t
5186end
5287function TestUtils. has_equal_data_internal (
5388 x:: P , y:: P , equal_undefs:: Bool , d:: Dict{Tuple{UInt,UInt},Bool}
54- ) where {P<: CuFloatArray }
89+ ) where {P<: Union{ CuFloatArray,CuComplexArray} }
5590 return isapprox (x, y)
5691end
57- function increment_internal!! (c:: IncCache , x:: P , y:: P ) where {P<: CuFloatArray }
92+ function TestUtils. has_equal_data_internal (
93+ x:: CuArray{P,N,M} , y:: CuArray{P,N,M} , equal_undefs:: Bool , d:: Dict{Tuple{UInt,UInt},Bool}
94+ ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M}
95+ x_ = reinterpret (Complex{T}, x)
96+ y_ = reinterpret (Complex{T}, y)
97+ return isapprox (x_, y_)
98+ end
99+ function increment_internal!! (
100+ c:: IncCache , x:: CuArray{P,N,M} , y:: CuArray{P,N,M}
101+ ) where {P<: IEEEFloat ,N,M}
58102 (x === y || haskey (c, x)) && return x
59103 c[x] = true
60104 x .+ = y
61105 return x
62106end
107+ function increment_internal!! (
108+ c:: IncCache , x:: CuArray{P,N,M} , y:: CuArray{P,N,M}
109+ ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M}
110+ (x === y || haskey (c, x)) && return x
111+ c[x] = true
112+ x_ = reinterpret (Complex{T}, x)
113+ y_ = reinterpret (Complex{T}, y)
114+ x_ .+ = y_
115+ return x
116+ end
63117__increment_should_allocate (:: Type{<:CuFloatArray} ) = true
64118set_to_zero_internal!! (:: Mooncake.SetToZeroCache , x:: CuFloatArray ) = x .= 0
119+ function set_to_zero_internal!! (
120+ :: Mooncake.SetToZeroCache , x:: CuArray{Mooncake.Tangent{@NamedTuple{re::T,im::T}},N,M}
121+ ) where {T<: IEEEFloat ,N,M}
122+ x_ = reinterpret (Complex{T}, x)
123+ x_ .= zero (Complex{T})
124+ return x
125+ end
126+
65127function _add_to_primal_internal (
66128 c:: MaybeCache , x:: P , y:: P , unsafe:: Bool
67129) where {P<: CuFloatArray }
@@ -71,24 +133,61 @@ function _add_to_primal_internal(
71133 c[(x, y, unsafe)] = x′
72134 return x′
73135end
136+ function _add_to_primal_internal (
137+ c:: MaybeCache , x:: P , y:: TP , unsafe:: Bool
138+ ) where {P<: CuComplexArray ,TP}
139+ key = (x, y, unsafe)
140+ haskey (c, key) && return c[key]:: P
141+ x′ = x + reinterpret (eltype (x), y)
142+ c[(x, y, unsafe)] = x′
143+ return x′
144+ end
74145function _diff_internal (c:: MaybeCache , x:: P , y:: P ) where {P<: CuFloatArray }
75146 key = (x, y)
76147 haskey (c, key) && return c[key]:: tangent_type (P)
77148 t = x - y
78149 c[key] = t
79150 return t
80151end
152+ function _diff_internal (c:: MaybeCache , x:: P , y:: P ) where {P<: CuComplexArray }
153+ key = (x, y)
154+ haskey (c, key) && return c[key]:: tangent_type (P)
155+ t = tangent_type (P)(undef, size (x))
156+ t_ = reinterpret (eltype (x), t)
157+ @. t_ = x - y
158+ c[key] = t
159+ return t
160+ end
81161function _dot_internal (c:: MaybeCache , x:: P , y:: P ) where {P<: CuFloatArray }
82162 key = (x, y)
83163 haskey (c, key) && return c[key]:: Float64
84164 return Float64 (dot (x, y))
85165end
166+ function _dot_internal (
167+ c:: MaybeCache , x:: CuArray{P} , y:: CuArray{P}
168+ ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} }
169+ key = (x, y)
170+ haskey (c, key) && return c[key]:: Float64
171+ x_ = reinterpret (Complex{T}, x)
172+ y_ = reinterpret (Complex{T}, y)
173+ return Float64 (real (dot (x_, y_)))
174+ end
86175function _scale_internal (c:: MaybeCache , x:: Float64 , y:: P ) where {T<: IEEEFloat ,P<: CuArray{T} }
87176 haskey (c, y) && return c[y]:: P
88177 t′ = T (x) * y
89178 c[y] = t′
90179 return t′
91180end
181+ function _scale_internal (
182+ c:: MaybeCache , x:: Float64 , y:: CuArray{P,N,M}
183+ ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M}
184+ haskey (c, y) && return c[y]:: CuArray{P,N,M}
185+ t′ = copy (y)
186+ t′_ = reinterpret (Complex{T}, t′)
187+ t′_ .*= T (x)
188+ c[y] = t′
189+ return t′
190+ end
92191function populate_address_map_internal (m:: AddressMap , p:: CuArray , t:: CuArray )
93192 k = pointer_from_objref (p)
94193 v = pointer_from_objref (t)
@@ -102,8 +201,16 @@ function Mooncake.__verify_fdata_value(::IdDict{Any,Nothing}, p::CuArray, f::CuA
102201 end
103202 return nothing
104203end
105- Mooncake. @foldable tangent_type (:: Type{P} , :: Type{NoRData} ) where {P<: CuArray } = P
106- tangent (p:: CuArray , :: NoRData ) = p
204+ Mooncake. @foldable tangent_type (:: Type{P} , :: Type{NoRData} ) where {P<: CuFloatArray } = P
205+ Mooncake. @foldable tangent_type (:: Type{CuArray{P,N,M}} , :: Type{NoRData} ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M} = CuArray{
206+ P,N,M
207+ }
208+ tangent (p:: CuFloatArray , :: NoRData ) = p
209+ function tangent (
210+ p:: CuArray{P,N,M} , :: NoRData
211+ ) where {T<: IEEEFloat ,P<: Mooncake.Tangent{@NamedTuple{re::T,im::T}} ,N,M}
212+ p
213+ end
107214
108215to_cr_tangent (x:: CuFloatArray ) = x
109216function increment_and_get_rdata! (f:: T , :: NoRData , t:: T ) where {T<: CuFloatArray }
@@ -120,5 +227,13 @@ function rrule!!(
120227 _dims = map (primal, dims)
121228 return CoDual (P (undef, _dims), P (undef, _dims)), NoPullback (p, init, dims... )
122229end
230+ function rrule!! (
231+ p:: CoDual{Type{P}} , init:: CoDual{UndefInitializer} , dims:: CoDual{Int} ...
232+ ) where {P<: CuComplexArray }
233+ _dims = map (primal, dims)
234+ return (
235+ CoDual (P (undef, _dims), tangent_type (P)(undef, _dims)), NoPullback (p, init, dims... )
236+ )
237+ end
123238
124239end
0 commit comments