|
| 1 | +# Softmax example - Julia port of cuTile Python TileGym softmax |
| 2 | +# |
| 3 | +# Column-major layout: input is (N, M) where N (softmax dimension) is contiguous. |
| 4 | +# Two strategies: |
| 5 | +# 1. TMA: loads entire column in one tile (small-to-medium N, TILE_SIZE >= N) |
| 6 | +# 2. Chunked: 3-pass with gather/scatter (large N, arbitrary TILE_SIZE) |
| 7 | +# |
| 8 | +# SPDX-License-Identifier: Apache-2.0 |
| 9 | + |
| 10 | +using CUDA |
| 11 | +import cuTile as ct |
| 12 | + |
| 13 | +#============================================================================= |
| 14 | + TMA Softmax Kernel (single-tile per column, persistent scheduling) |
| 15 | +=============================================================================# |
| 16 | + |
# Numerically-stable softmax over each column of `input`, written to `output`.
# A single (TILE_SIZE, 1) tile covers the whole column, so this variant requires
# TILE_SIZE >= N (the softmax/contiguous dimension). Blocks are persistent: each
# block processes columns pid, pid + gridsize, pid + 2*gridsize, ...
function softmax_tma_kernel(output::ct.TileArray{T,2}, input::ct.TileArray{T,2},
                            TILE_SIZE::Int) where {T}
    # Compiler hint: target 2 resident blocks per SM.
    ct.@compiler_options occupancy=2

    pid = ct.bid(1)                  # this block's id along launch dim 1
    num_programs = ct.num_blocks(1)  # grid size = stride of the persistent loop
    M = size(input, 2)               # number of columns (softmax instances)

    col_idx = pid
    while col_idx <= M
        # Load the entire column as one tile. Rows past the end of the array are
        # padded with -Inf, so they contribute exp(-Inf) = 0 to the sum below.
        col = ct.load(input; index=(Int32(1), col_idx), shape=(TILE_SIZE, 1),
                      padding_mode=ct.PaddingMode.NegInf)
        # Accumulate in Float32 regardless of the storage type T (e.g. Float16).
        col = convert(ct.Tile{Float32}, col)

        # Stable softmax: subtract the column max before exponentiating.
        col_max = maximum(col; dims=1)
        numerator = exp.(col .- col_max)
        denominator = sum(numerator; dims=1)
        softmax_output = numerator ./ denominator

        # Convert back to the output element type and store the column in place.
        ct.store(output; index=(Int32(1), col_idx),
                 tile=convert(ct.Tile{T}, softmax_output))
        col_idx += num_programs  # persistent scheduling: advance to next column
    end
    return
end
| 42 | + |
| 43 | +#============================================================================= |
| 44 | + Chunked Softmax Kernel (3-pass with gather/scatter, persistent scheduling) |
| 45 | +=============================================================================# |
| 46 | + |
# Softmax for columns longer than one tile: a classic 3-pass scheme over
# TILE_SIZE-sized chunks of each column, using gather/scatter instead of tiled
# load/store. Blocks are persistent and stride over columns by the grid size.
#
# `n_elems` is the softmax length N; `TILE_SIZE` is the chunk length (arbitrary,
# need not be >= N).
function softmax_chunked_kernel(output::ct.TileArray{T,2}, input::ct.TileArray{T,2},
                                n_elems::Int, TILE_SIZE::Int) where {T}
    # Compiler hint: target 4 resident blocks per SM.
    ct.@compiler_options occupancy=4

    pid = ct.bid(1)
    num_programs = ct.num_blocks(1)
    M = size(input, 2)
    # Ceiling division: number of chunks needed to cover one column.
    num_chunks = (n_elems + TILE_SIZE - Int32(1)) ÷ Int32(TILE_SIZE)
    # Per-lane offsets 0..TILE_SIZE-1, reused for every chunk's row indices.
    # NOTE(review): these gather indices start at 0 while `ct.load` in the TMA
    # kernel uses a 1-based index — presumably gather is 0-based; confirm against
    # the cuTile gather/scatter indexing convention.
    row_offsets_base = ct.arange(TILE_SIZE)

    col_idx = pid
    while col_idx <= M
        col_tile = ct.Tile(col_idx)
        # Running column max (scalar, kept as a 1-element tile); name is
        # historical — it is the max over rows of this column.
        row_max = fill(-Inf32, (1,))
        # Per-lane partial sums of exp(); reduced to a scalar after pass 2.
        denominator = zeros(Float32, TILE_SIZE)

        # Pass 1: Find maximum across all chunks
        for chunk_idx in Int32(0):num_chunks - Int32(1)
            row_indices = ct.broadcast_to(ct.Tile(chunk_idx * Int32(TILE_SIZE)), (TILE_SIZE,)) .+ row_offsets_base
            # Out-of-range lanes read -Inf, which is the identity for max.
            chunk = ct.gather(input, (row_indices, col_tile);
                              check_bounds=true, padding_value=T(-Inf))
            chunk = convert(ct.Tile{Float32}, chunk)
            chunk_max = maximum(chunk)
            row_max = max.(row_max, ct.Tile(chunk_max))
        end

        # Pass 2: Compute denominator (sum of all exp values)
        for chunk_idx in Int32(0):num_chunks - Int32(1)
            row_indices = ct.broadcast_to(ct.Tile(chunk_idx * Int32(TILE_SIZE)), (TILE_SIZE,)) .+ row_offsets_base
            chunk = ct.gather(input, (row_indices, col_tile);
                              check_bounds=true, padding_value=T(-Inf))
            chunk = convert(ct.Tile{Float32}, chunk)
            # Padding lanes contribute exp(-Inf - max) = 0, so no masking needed.
            denominator = denominator .+ exp.(chunk .- row_max)
        end
        denom_sum = ct.Tile(sum(denominator))

        # Pass 3: Compute final softmax and scatter
        for chunk_idx in Int32(0):num_chunks - Int32(1)
            row_indices = ct.broadcast_to(ct.Tile(chunk_idx * Int32(TILE_SIZE)), (TILE_SIZE,)) .+ row_offsets_base
            chunk = ct.gather(input, (row_indices, col_tile);
                              check_bounds=true, padding_value=T(-Inf))
            chunk = convert(ct.Tile{Float32}, chunk)
            softmax_output = exp.(chunk .- row_max) ./ denom_sum
            # check_bounds=true drops out-of-range lanes on the store side.
            ct.scatter(output, (row_indices, col_tile), convert(ct.Tile{T}, softmax_output);
                       check_bounds=true)
        end

        col_idx += num_programs  # persistent scheduling: advance to next column
    end
    return
end
| 98 | + |
| 99 | + |
| 100 | +#============================================================================= |
| 101 | + Example harness |
| 102 | +=============================================================================# |
| 103 | + |
"""
    next_power_of_2(n::Int)

Return the smallest power of two that is `>= n`; returns `1` for `n <= 0`.

The original hand-rolled doubling loop duplicated `Base.nextpow`, which
implements the same computation for positive arguments.
"""
function next_power_of_2(n::Int)
    # `Base.nextpow` throws DomainError for non-positive `x`, so guard that case
    # to preserve the original "return 1" behavior.
    return n <= 0 ? 1 : nextpow(2, n)
end
| 112 | + |
# Build the problem data: a random (N, M) device matrix plus one output buffer
# per kernel variant. With Julia's column-major layout the softmax dimension N
# is the contiguous one. `benchmark=true` selects the larger default sizes.
function prepare(; benchmark::Bool=false,
                 M::Int=benchmark ? 4096 : 256,
                 N::Int=benchmark ? 4096 : 256,
                 T::DataType=Float32)
    x = CUDA.randn(T, N, M)
    out_tma = similar(x)
    out_chunked = similar(x)
    return (input=x, output_tma=out_tma, output_chunked=out_chunked, M=M, N=N)
end
| 126 | + |
# Launch and time both kernel variants. Runs `warmup` untimed iterations of
# each, then `nruns` timed iterations; times are reported in milliseconds.
function run(data; tile_tma::Int=next_power_of_2(data.N),
             tile_chunked::Int=1024,
             nruns::Int=1, warmup::Int=0)
    (; input, output_tma, output_chunked, M, N) = data

    # One launcher closure per variant; one block ("program") per column.
    launch_tma() = ct.launch(softmax_tma_kernel, M, output_tma, input,
                             ct.Constant(tile_tma))
    launch_chunked() = ct.launch(softmax_chunked_kernel, M, output_chunked, input,
                                 ct.Constant(N), ct.Constant(tile_chunked))

    # Warmup (JIT compilation, caches); synchronize once at the end.
    CUDA.@sync for _ in 1:warmup
        launch_tma()
        launch_chunked()
    end

    # Time `f` nruns times; seconds -> milliseconds, collected as Float64.
    time_ms(f) = Float64[1000 * CUDA.@elapsed(f()) for _ in 1:nruns]

    return (; output_tma, output_chunked,
            times=Dict("cuTile TMA" => time_ms(launch_tma),
                       "cuTile Chunked" => time_ms(launch_chunked)))
end
| 164 | + |
# Validate both kernel outputs column-by-column against a CPU reference softmax
# (max-subtracted for numerical stability), using loose float tolerances.
function verify(data, result)
    host_input = Array(data.input)  # (N, M): each column is one softmax instance
    for label in (:output_tma, :output_chunked)
        got = Array(getproperty(result, label))
        for j in 1:data.M
            column = host_input[:, j]
            shifted = exp.(column .- maximum(column))
            reference = shifted ./ sum(shifted)
            @assert isapprox(got[:, j], reference; atol=1e-5, rtol=1e-4) "$label column $j mismatch"
        end
    end
end
| 179 | + |
# Bytes moved per kernel variant, used to convert run times into GB/s.
function metric(data)
    bytes = sizeof(Float32) * data.M * data.N
    # TMA touches each element twice (one load + one store); the chunked kernel
    # gathers the input once per pass (3 passes) and scatters the result once.
    return Dict(
        "cuTile TMA" => (2 * bytes, "GB/s"),
        "cuTile Chunked" => (4 * bytes, "GB/s"),
    )
end
| 189 | + |
| 190 | + |
| 191 | +#============================================================================= |
| 192 | + Reference implementations for benchmarking |
| 193 | +=============================================================================# |
| 194 | + |
# Time a broadcast-based (GPUArrays) softmax as a reference baseline.
# Returns a Dict of label => per-run times in milliseconds.
function run_others(data; nruns::Int=1, warmup::Int=0)
    x = data.input
    buffer = similar(x)

    # Column-wise stable softmax expressed purely with array broadcasting.
    function baseline!()
        colmax = maximum(x; dims=1)
        expshift = exp.(x .- colmax)
        buffer .= expshift ./ sum(expshift; dims=1)
    end

    # Warmup (compilation), synchronized once at the end.
    CUDA.@sync for _ in 1:warmup
        baseline!()
    end

    elapsed_ms = Float64[]
    for _ in 1:nruns
        push!(elapsed_ms, 1000 * CUDA.@elapsed(baseline!()))
    end

    return Dict{String, Vector{Float64}}("GPUArrays" => elapsed_ms)
end
| 219 | + |
| 220 | + |
| 221 | +#============================================================================= |
| 222 | + Main |
| 223 | +=============================================================================# |
| 224 | + |
# Run one (M, N) problem end-to-end: prepare data, launch both variants,
# verify against the CPU reference, and print a pass message.
function test_softmax(M, N; tile_tma::Int=next_power_of_2(N), tile_chunked::Int=1024, name=nothing)
    label = something(name, "softmax ($M x $N), tma_tile=$tile_tma, chunked_tile=$tile_chunked")
    println("--- $label ---")
    problem = prepare(; M, N)
    outputs = run(problem; tile_tma, tile_chunked)
    verify(problem, outputs)
    println(" tma passed, chunked passed")
end
| 233 | + |
# Entry point: exercise the softmax example at a few square problem sizes.
function main()
    println("--- cuTile Softmax Examples ---\n")

    for dim in (256, 1024, 4096)
        test_softmax(dim, dim)
    end

    println("\n--- All softmax examples completed ---")
end
| 243 | + |
| 244 | +isinteractive() || main() |
0 commit comments