Skip to content

Commit 3c5740e

Browse files
authored
Add softmax example (#185)
1 parent 2c1a1a7 commit 3c5740e

File tree

5 files changed

+483
-4
lines changed

5 files changed

+483
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (`tileiras`
107107
| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 592 μs | 562 μs | OK (+5%) |
108108
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 18.8 TFLOPS | 20.3 TFLOPS | -7% |
109109
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 89.3 TFLOPS | 63.9 TFLOPS | +40%*** |
110+
| Softmax (TMA) | 4096² f32 | 806 GB/s | 838 GB/s | OK (-4%) |
111+
| Softmax (Chunked) | 4096² f32 | 1587 GB/s | 1676 GB/s | OK (-5%) |
110112

111113
\* The `pow(x, 2)` → `mulf(x, x)` strength reduction eliminates the expensive
112114
transcendental in the variance computation. Python still emits `pow`.

examples/benchmarks.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,14 @@ function run_benchmark(name::String)
100100
# Run cuTile
101101
result = @invokelatest mod.run(data; nruns=NRUNS, warmup=WARMUP)
102102

103-
# Extract times (handle times_fwd/times_bwd for layernorm)
103+
# Extract times from result
104104
if hasproperty(result, :times)
105-
results = Dict{String, Vector{Float64}}("cuTile" => result.times)
105+
t = result.times
106+
if t isa Dict
107+
results = t
108+
else
109+
results = Dict{String, Vector{Float64}}("cuTile" => t)
110+
end
106111
elseif hasproperty(result, :times_fwd)
107112
results = Dict{String, Vector{Float64}}(
108113
"cuTile Fwd" => result.times_fwd,

examples/benchmarks.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,13 @@ def run_benchmark(name: str):
105105
# Run cuTile
106106
result = run_fn(data, nruns=NRUNS, warmup=WARMUP)
107107

108-
# Extract times (handle times_fwd/times_bwd for layernorm)
108+
# Extract times from result
109109
if "times" in result:
110-
results = {"cuTile": result["times"]}
110+
t = result["times"]
111+
if isinstance(t, dict):
112+
results = t
113+
else:
114+
results = {"cuTile": t}
111115
elif "times_fwd" in result:
112116
results = {
113117
"cuTile Fwd": result["times_fwd"],

examples/softmax.jl

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Softmax example - Julia port of cuTile Python TileGym softmax
2+
#
3+
# Column-major layout: input is (N, M) where N (softmax dimension) is contiguous.
4+
# Two strategies:
5+
# 1. TMA: loads entire column in one tile (small-to-medium N, TILE_SIZE >= N)
6+
# 2. Chunked: 3-pass with gather/scatter (large N, arbitrary TILE_SIZE)
7+
#
8+
# SPDX-License-Identifier: Apache-2.0
9+
10+
using CUDA
11+
import cuTile as ct
12+
13+
#=============================================================================
14+
TMA Softmax Kernel (single-tile per column, persistent scheduling)
15+
=============================================================================#
16+
17+
# Single-tile softmax: each program handles whole columns, loading the entire
# softmax dimension as one (TILE_SIZE, 1) tile. Requires TILE_SIZE >= N (see
# file header). Persistent scheduling: the grid may be smaller than M, so
# each program strides over columns by the total program count.
function softmax_tma_kernel(output::ct.TileArray{T,2}, input::ct.TileArray{T,2},
                            TILE_SIZE::Int) where {T}
    ct.@compiler_options occupancy=2

    pid = ct.bid(1)                  # this program's id along launch dim 1
    num_programs = ct.num_blocks(1)  # grid size -> persistent-loop stride
    M = size(input, 2)               # number of columns (batch dimension)

    col_idx = pid
    while col_idx <= M
        # Rows past the end of the column (when TILE_SIZE > N) are padded
        # with -Inf so they contribute exp(-Inf) == 0 to the reductions.
        col = ct.load(input; index=(Int32(1), col_idx), shape=(TILE_SIZE, 1),
                      padding_mode=ct.PaddingMode.NegInf)
        # Accumulate in f32 regardless of the storage type T.
        col = convert(ct.Tile{Float32}, col)

        # Numerically stable softmax: shift by the column max before exp.
        col_max = maximum(col; dims=1)
        numerator = exp.(col .- col_max)
        denominator = sum(numerator; dims=1)
        softmax_output = numerator ./ denominator

        # Convert back to the output storage type on store.
        ct.store(output; index=(Int32(1), col_idx),
                 tile=convert(ct.Tile{T}, softmax_output))
        col_idx += num_programs
    end
    return
end
42+
43+
#=============================================================================
44+
Chunked Softmax Kernel (3-pass with gather/scatter, persistent scheduling)
45+
=============================================================================#
46+
47+
# Chunked softmax for softmax dimensions larger than one tile. Each column is
# covered by ceil(n_elems / TILE_SIZE) chunks and processed in three passes
# (max, sum-of-exp, normalize) using gather/scatter. Out-of-range gathers pad
# with -Inf, which drops out of both the max and the exp-sum.
function softmax_chunked_kernel(output::ct.TileArray{T,2}, input::ct.TileArray{T,2},
                                n_elems::Int, TILE_SIZE::Int) where {T}
    ct.@compiler_options occupancy=4

    pid = ct.bid(1)                  # this program's id along launch dim 1
    num_programs = ct.num_blocks(1)  # persistent-loop stride over columns
    M = size(input, 2)
    # Ceiling division: chunks of TILE_SIZE rows needed to cover n_elems.
    num_chunks = (n_elems + TILE_SIZE - Int32(1)) ÷ Int32(TILE_SIZE)
    row_offsets_base = ct.arange(TILE_SIZE)  # 0..TILE_SIZE-1 within a chunk

    col_idx = pid
    while col_idx <= M
        col_tile = ct.Tile(col_idx)
        row_max = fill(-Inf32, (1,))             # running column max
        denominator = zeros(Float32, TILE_SIZE)  # per-lane partial exp-sums

        # Pass 1: Find maximum across all chunks
        for chunk_idx in Int32(0):num_chunks - Int32(1)
            row_indices = ct.broadcast_to(ct.Tile(chunk_idx * Int32(TILE_SIZE)), (TILE_SIZE,)) .+ row_offsets_base
            chunk = ct.gather(input, (row_indices, col_tile);
                              check_bounds=true, padding_value=T(-Inf))
            chunk = convert(ct.Tile{Float32}, chunk)
            chunk_max = maximum(chunk)
            row_max = max.(row_max, ct.Tile(chunk_max))
        end

        # Pass 2: Compute denominator (sum of all exp values)
        for chunk_idx in Int32(0):num_chunks - Int32(1)
            row_indices = ct.broadcast_to(ct.Tile(chunk_idx * Int32(TILE_SIZE)), (TILE_SIZE,)) .+ row_offsets_base
            chunk = ct.gather(input, (row_indices, col_tile);
                              check_bounds=true, padding_value=T(-Inf))
            chunk = convert(ct.Tile{Float32}, chunk)
            # Padded lanes give exp(-Inf - max) == 0, so the sum is unaffected.
            denominator = denominator .+ exp.(chunk .- row_max)
        end
        # Reduce the per-lane partial sums to the full column denominator.
        denom_sum = ct.Tile(sum(denominator))

        # Pass 3: Compute final softmax and scatter
        for chunk_idx in Int32(0):num_chunks - Int32(1)
            row_indices = ct.broadcast_to(ct.Tile(chunk_idx * Int32(TILE_SIZE)), (TILE_SIZE,)) .+ row_offsets_base
            chunk = ct.gather(input, (row_indices, col_tile);
                              check_bounds=true, padding_value=T(-Inf))
            chunk = convert(ct.Tile{Float32}, chunk)
            softmax_output = exp.(chunk .- row_max) ./ denom_sum
            # NOTE(review): check_bounds presumably suppresses out-of-range
            # writes in the last partial chunk — confirm against ct.scatter docs.
            ct.scatter(output, (row_indices, col_tile), convert(ct.Tile{T}, softmax_output);
                       check_bounds=true)
        end

        col_idx += num_programs
    end
    return
end
98+
99+
100+
#=============================================================================
101+
Example harness
102+
=============================================================================#
103+
104+
"""
    next_power_of_2(n::Int)

Return the smallest power of two `>= n`. Returns 1 for `n <= 0`.
"""
function next_power_of_2(n::Int)
    # Base.nextpow requires a positive argument, so guard non-positive n;
    # otherwise delegate instead of hand-rolling the doubling loop.
    return n <= 0 ? 1 : nextpow(2, n)
end
112+
113+
"""
    prepare(; benchmark=false, M, N, T=Float32)

Allocate a random `(N, M)` input on the GPU plus one output buffer per
kernel strategy. Defaults to 4096x4096 in benchmark mode, 256x256 otherwise.
"""
function prepare(; benchmark::Bool=false,
                 M::Int=benchmark ? 4096 : 256,
                 N::Int=benchmark ? 4096 : 256,
                 T::DataType=Float32)
    # (N, M) layout: the softmax dimension N is contiguous in column-major.
    x = CUDA.randn(T, N, M)
    y_tma = similar(x)
    y_chunked = similar(x)
    return (; input=x, output_tma=y_tma, output_chunked=y_chunked, M, N)
end
126+
127+
# Time `nruns` invocations of the zero-argument launcher `f`; returns the
# per-run wall times in milliseconds.
function _time_runs(f, nruns::Int)
    times = Float64[]
    for _ in 1:nruns
        t = CUDA.@elapsed f()
        push!(times, t * 1000)
    end
    return times
end

"""
    run(data; tile_tma, tile_chunked, nruns=1, warmup=0)

Launch both softmax kernels on `data` (from `prepare`). `tile_tma` defaults
to the next power of two covering the softmax dimension `N` (the TMA kernel
loads a whole column per tile); `tile_chunked` is the per-pass chunk size of
the 3-pass kernel. Returns the output buffers and a Dict of per-run times in
milliseconds keyed by implementation name.
"""
function run(data; tile_tma::Int=next_power_of_2(data.N),
             tile_chunked::Int=1024,
             nruns::Int=1, warmup::Int=0)
    (; input, output_tma, output_chunked, M, N) = data

    run_tma() = ct.launch(softmax_tma_kernel, M, output_tma, input, ct.Constant(tile_tma))
    run_chunked() = ct.launch(softmax_chunked_kernel, M, output_chunked, input,
                              ct.Constant(N), ct.Constant(tile_chunked))

    # Warmup (untimed)
    CUDA.@sync for _ in 1:warmup
        run_tma()
        run_chunked()
    end

    # Timed runs — the duplicated timing loops are factored into _time_runs.
    times_tma = _time_runs(run_tma, nruns)
    times_chunked = _time_runs(run_chunked, nruns)

    return (; output_tma, output_chunked,
            times=Dict("cuTile TMA" => times_tma, "cuTile Chunked" => times_chunked))
end
164+
165+
"""
    verify(data, result)

Check both kernel outputs in `result` against a CPU reference softmax,
column by column. Throws on the first mismatching column.
"""
function verify(data, result)
    M = data.M
    x = Array(data.input)  # (N, M): softmax dimension first
    for label in (:output_tma, :output_chunked)
        out = Array(getproperty(result, label))
        for j in 1:M
            col = x[:, j]
            # Numerically stable reference: shift by the column max.
            exps = exp.(col .- maximum(col))
            expected = exps ./ sum(exps)
            # Explicit throw instead of @assert: assertions may be disabled
            # at higher optimization levels, silently skipping verification.
            isapprox(out[:, j], expected; atol=1e-5, rtol=1e-4) ||
                error("$label column $j mismatch")
        end
    end
end
179+
180+
"""
    metric(data)

Return a Dict mapping implementation name to `(bytes_moved, unit)` for
bandwidth reporting.
"""
function metric(data)
    # Use the actual element type of the prepared input rather than
    # hard-coding Float32, so runs with other eltypes report correct bytes.
    bytes = data.M * data.N * sizeof(eltype(data.input))
    return Dict(
        # TMA: 1 read + 1 write
        "cuTile TMA" => (2 * bytes, "GB/s"),
        # Chunked: 3 reads (gather per pass) + 1 write (scatter)
        "cuTile Chunked" => (4 * bytes, "GB/s"),
    )
end
189+
190+
191+
#=============================================================================
192+
Reference implementations for benchmarking
193+
=============================================================================#
194+
195+
"""
    run_others(data; nruns=1, warmup=0)

Benchmark a reference softmax built from plain GPU array broadcasting.
Returns a Dict mapping implementation name to per-run times in milliseconds.
"""
function run_others(data; nruns::Int=1, warmup::Int=0)
    x = data.input
    y = similar(x)

    # Column-wise numerically-stable softmax via GPUArrays broadcasting.
    softmax_ref!() = begin
        m = maximum(x; dims=1)
        e = exp.(x .- m)
        y .= e ./ sum(e; dims=1)
    end

    CUDA.@sync for _ in 1:warmup
        softmax_ref!()
    end

    samples = Float64[]
    for _ in 1:nruns
        elapsed = CUDA.@elapsed softmax_ref!()
        push!(samples, elapsed * 1000)
    end

    return Dict{String, Vector{Float64}}("GPUArrays" => samples)
end
219+
220+
221+
#=============================================================================
222+
Main
223+
=============================================================================#
224+
225+
"""
    test_softmax(M, N; tile_tma, tile_chunked, name=nothing)

Run both softmax kernels on a random (N, M) problem and verify the results
against a CPU reference, printing a banner and a pass message.
"""
function test_softmax(M, N; tile_tma::Int=next_power_of_2(N), tile_chunked::Int=1024, name=nothing)
    label = something(name, "softmax ($M x $N), tma_tile=$tile_tma, chunked_tile=$tile_chunked")
    println("--- $label ---")
    data = prepare(; M, N)
    verify(data, run(data; tile_tma, tile_chunked))
    println(" tma passed, chunked passed")
end
233+
234+
# Entry point: exercise both kernels at several square problem sizes.
function main()
    println("--- cuTile Softmax Examples ---\n")

    for size in (256, 1024, 4096)
        test_softmax(size, size)
    end

    println("\n--- All softmax examples completed ---")
end
243+
244+
isinteractive() || main()

0 commit comments

Comments
 (0)