Changes from all commits
22 commits
d441578
change default dumping dir for profiling to prevent profile hanging a…
yrq0208 Nov 21, 2025
c4a1f74
modify path to point to a local up to date tornado
yrq0208 Nov 21, 2025
3a2b55c
python venv
yrq0208 Nov 21, 2025
78c15d4
Merge branch 'beehive-lab:main' into RMS_opt_fusion
yrq0208 Nov 25, 2025
3167260
experimental changes trying to fuse reduction and map context in FFN …
yrq0208 Nov 27, 2025
cab3e9e
Merge branch 'beehive-lab:main' into main
yrq0208 Nov 27, 2025
48345ff
rms fuse opts
yrq0208 Nov 28, 2025
f69e9d9
remove comments, refactor host code to reflect the merge of reduction…
yrq0208 Dec 1, 2025
d57d1ce
Merge branch 'beehive-lab:main' into RMS_opt_fusion
yrq0208 Dec 1, 2025
676831e
Merge branch 'beehive-lab:main' into main
yrq0208 Dec 2, 2025
7c1a63e
fix the fused reduction kernel and refactor the host code accordingly…
yrq0208 Dec 9, 2025
639fa08
change in tornado
yrq0208 Dec 10, 2025
ffb6651
Merge branch 'main' into RMS_opt_fusion
yrq0208 Dec 10, 2025
884d197
update on reduction fuse
yrq0208 Dec 10, 2025
9336dc8
remove comments
yrq0208 Dec 10, 2025
68f7d1f
change in tornado
yrq0208 Dec 15, 2025
7b1d172
reduction fuse opt in RMS normalization layer for llama after the rec…
yrq0208 Dec 15, 2025
3f542b7
remove comments
yrq0208 Dec 15, 2025
ccbc2ea
revert unnecessary changes
yrq0208 Dec 15, 2025
b330bcf
remove external folder
yrq0208 Dec 15, 2025
e4cb5fb
revert changes
yrq0208 Dec 16, 2025
fba612d
Merge branch 'beehive-lab:main' into RMS_opt_fusion
yrq0208 Jan 27, 2026
2 changes: 1 addition & 1 deletion set_paths
Expand Up @@ -22,4 +22,4 @@ echo "[INFO] Environment configured for LLaMA3 with TornadoVM at: $TORNADOVM_HOM
# 3. You can run LLaMA3 with GPU acceleration using TornadoVM
#
# To use this script: source ./setup_environment.sh
# or: . ./setup_environment.sh
# or: . ./setup_environment.sh
Expand Up @@ -284,6 +284,7 @@ public static void fusedRmsNormFFNGateUpQ8_0(
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
Expand Down Expand Up @@ -331,20 +332,170 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
}

/**
* Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of squares and the scaling factor across all
* work groups, then applies the resulting normalization factor to the input and weight elements.
*
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* Algorithm: 1. Each thread computes the square of its input element 2. Each work group performs a parallel reduction of the squares 3. Partial sums are stored per work group 4. Each thread
* combines all partial sums and computes the normalization factor 5. The normalization factor is applied to the input and weight elements.
*
* @param context
* Kernel execution context
* @param output
* Array for normalized output
* @param x
* Input array to normalize
* @param weights
* Weight values for each element
* @param temp
* Temporary array holding the per-work-group partial sums
* @param size
* Number of elements to process
* @param ermsNorm
* Epsilon value squared for numerical stability
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayerFuse(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
int groupId = context.groupIdx;
int groupSize = context.localGroupSizeX;

// Allocate local memory with the provided size
float[] localX = context.allocateFloatLocalArray(localMemSize);

// Load input value and compute square
if (gid < size) {
float v = x.get(gid);
localX[lid] = v * v;
} else {
localX[lid] = 0.0f;
}

// Perform parallel reduction within the work group
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (lid < stride) {
localX[lid] += localX[lid + stride];
}
}

// Each workgroup stores its partial sum in a different location
if (lid == 0) {
// Store the partial sum from each workgroup
temp.set(groupId, localX[0]);
}

context.globalBarrier();

float localss = 0.0f;
int numGroups = (size + groupSize - 1) / groupSize;
for (int i = 0; i < numGroups; i++) { // Combine the partial sums from all work groups
localss += temp.get(i);
}
localss /= size;
localss += ermsNorm;
localss = 1.0f / TornadoMath.sqrt(localss);

if (gid < size) {
float in = x.get(gid);
float w = weights.get(gid);
output.set(gid, w * (localss * in));
}
}
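
/*
 * A minimal sequential reference sketch of what the fused kernel above computes (illustrative
 * only, not part of this change; the helper name is hypothetical): mean of squares, plus
 * epsilon, inverse square root, then each element scaled by its weight.
 */
static void rmsNormReferenceSketch(float[] out, float[] x, float[] w, float ermsNorm) {
    float ss = 0.0f;
    for (float v : x) {
        ss += v * v; // sum of squares
    }
    ss = 1.0f / (float) Math.sqrt(ss / x.length + ermsNorm); // normalization factor
    for (int i = 0; i < x.length; i++) {
        out[i] = w[i] * (ss * x[i]); // weight[i] * (normFactor * x[i])
    }
}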

/**
* Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of squares and the scaling factor across all
* work groups, then applies the resulting normalization factor to the input and weight elements.
*
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* Algorithm: 1. Each thread computes the square of its input element 2. Each work group performs a parallel reduction of the squares 3. Partial sums are stored per work group 4. Each thread
* combines all partial sums and computes the normalization factor 5. The normalization factor is applied to the input and weight elements.
*
* @param context
* Kernel execution context
* @param outputFP16
* Half float array for the normalized output
* @param x
* Input array to normalize
* @param weights
* Weight values for each element
* @param temp
* Temporary array holding the per-work-group partial sums
* @param size
* Number of elements to process
* @param ermsNorm
* Epsilon value squared for numerical stability
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayerFuseFP16(KernelContext context, HalfFloatArray outputFP16, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
int groupId = context.groupIdx;
int groupSize = context.localGroupSizeX;

// Allocate local memory with the provided size
float[] localX = context.allocateFloatLocalArray(localMemSize);

// Load input value and compute square
if (gid < size) {
float v = x.get(gid);
localX[lid] = v * v;
} else {
localX[lid] = 0.0f;
}

// Perform parallel reduction within the work group
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (lid < stride) {
localX[lid] += localX[lid + stride];
}
}

// Each workgroup stores its partial sum in a different location
if (lid == 0) {
// Store the partial sum from each workgroup
temp.set(groupId, localX[0]);
}

context.globalBarrier();

float localss = 0.0f;
int numGroups = (size + groupSize - 1) / groupSize;
for (int i = 0; i < numGroups; i++) { // Combine the partial sums from all work groups
localss += temp.get(i);
}
localss /= size;
localss += ermsNorm;
localss = 1.0f / TornadoMath.sqrt(localss);

if (gid < size) {
float in = x.get(gid);
float w = weights.get(gid);
outputFP16.set(gid, new HalfFloat(w * (localss * in)));
}
}


/**
* Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* @param context Kernel execution context
* @param output Array for normalized output
* @param x Input values to normalize
* @param weights Weight values for each element
* @param temp Temporary array containing normalization factor at index 0
*/
public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
int gid = context.globalIdx;
Expand All @@ -355,25 +506,17 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray

/**
* Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
*
* <p>
* Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector
*
* @param destKeyCache
* Destination array for key cache
* @param srcKey
* Source keys to copy
* @param destValueCache
* Destination array for value cache
* @param srcValue
* Source values to copy
* @param positioNlayer
* Array containing current position
* @param kvDim
* Dimension of key/value vectors
* @param layer
* Current transformer layer index
* @param contextLength
* Maximum sequence length
* @param destKeyCache Destination array for key cache
* @param srcKey Source keys to copy
* @param destValueCache Destination array for value cache
* @param srcValue Source values to copy
* @param positioNlayer Array containing current position
* @param kvDim Dimension of key/value vectors
* @param layer Current transformer layer index
* @param contextLength Maximum sequence length
*/
public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
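// Hedged sketch of the flat indexing implied by the [layer][position][dimension] layout described
// above; the names mirror the parameters, and the exact elided body may differ:
//   int loff = layer * contextLength * kvDim;       // start of this layer's cache
//   int pos = positioNlayer.get(0);                  // current position (index 0 assumed)
//   destKeyCache.set(loff + pos * kvDim + i, srcKey.get(i));
//   destValueCache.set(loff + pos * kvDim + i, srcValue.get(i));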

Expand Down Expand Up @@ -463,21 +606,15 @@ public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArr
/**
* Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional
* information.
*
* <p>
* For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair
*
* @param context
* Kernel execution context
* @param positionHolder
* Array containing current position
* @param sq
* Query vectors to rotate
* @param sk
* Key vectors to rotate
* @param kv_dim
* Dimension of key/value vectors
* @param head_size
* Dimension of each attention head
* @param context Kernel execution context
* @param positionHolder Array containing current position
* @param sq Query vectors to rotate
* @param sk Key vectors to rotate
* @param kv_dim Dimension of key/value vectors
* @param head_size Dimension of each attention head
*/
public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
int i = context.globalIdx * 2;
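// Hedged sketch of the rotation the elided body is expected to apply to each pair
// (standard RoPE; the 10000 frequency base is an assumption):
//   freq = 1 / pow(10000, (i % head_size) / (float) head_size)
//   val  = position * freq
//   (a, b) -> (a * cos(val) - b * sin(val), a * sin(val) + b * cos(val))
// applied to the query pair (sq[i], sq[i+1]) and, while i < kv_dim, also to the key pair (sk[i], sk[i+1]).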
Expand Down Expand Up @@ -552,31 +689,20 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold

/**
* Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
*
* <p>
* Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. Compute weighted sum of values
*
* @param allQ
* All query vectors
* @param key_cache
* Cached keys
* @param value_cache
* Cached values
* @param allXb
* Output buffer
* @param h
* Head index to process
* @param headSize
* Dimension per head
* @param kvDim
* Key/value dimension
* @param kvMul
* Key multiplier for grouped attention
* @param loff
* Layer offset in cache
* @param pos
* Current position
* @param wrapAtt
* Attention weights buffer
* @param allQ All query vectors
* @param key_cache Cached keys
* @param value_cache Cached values
* @param allXb Output buffer
* @param h Head index to process
* @param headSize Dimension per head
* @param kvDim Key/value dimension
* @param kvMul Key multiplier for grouped attention
* @param loff Layer offset in cache
* @param pos Current position
* @param wrapAtt Attention weights buffer
*/
private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
FloatArray wrapAtt) {
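// Hedged sketch of the per-head computation performed by the elided body:
//   score[t] = dot(q_h, k_t) / sqrt(headSize)                              for t = 0..pos
//   att[t]   = exp(score[t] - maxScore) / sum_j exp(score[j] - maxScore)   (softmax with max subtraction)
//   xb_h     = sum_t att[t] * v_t                                          (weighted sum of cached values)
// where k_t and v_t are read from key_cache / value_cache at layer offset loff.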
Expand Down Expand Up @@ -1117,23 +1243,16 @@ public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArr

/**
* Performs optimized matrix-vector multiplication where each work group processes one row of the matrix.
*
* <p>
* Algorithm: 1. Each work group handles one output dimension 2. Threads in work group compute partial dot products 3. Parallel reduction yields final row result
*
* @param context
* Kernel execution context
* @param x
* Input vector
* @param hb
* Output vector
* @param w
* Weight matrix (row-major)
* @param n
* Input dimension
* @param d
* Output dimension
* @param localWorkGroupSize
* Number of threads per work group
* @param context Kernel execution context
* @param x Input vector
* @param hb Output vector
* @param w Weight matrix (row-major)
* @param n Input dimension
* @param d Output dimension
* @param localWorkGroupSize Number of threads per work group
*/
public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
// One row per workgroup (not per thread)
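// Hedged sketch of the result this kernel produces (sequential form, illustrative only):
//   for (int row = 0; row < d; row++) {
//       float sum = 0.0f;
//       for (int k = 0; k < n; k++) {
//           sum += w.get(row * n + k) * x.get(k); // partial dot products, reduced per work group
//       }
//       hb.set(row, sum);
//   }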
Expand Down
Expand Up @@ -50,7 +50,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
for (int i = 0; i < config.numberOfLayers(); i++) {
// === Attention Block ===
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
Expand Down Expand Up @@ -199,21 +198,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
// === Attention Block ===
// RMS Normalization
unifiedLayer.task("attn_rms_reduce",
TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
context, state.temp, state.wrapX,
TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuseFP16,
context, state.wrapXbFP16, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp,
config.dim(), config.rmsNormEps(), state.localSize);

if (shouldUseFinalNormalization()) {
unifiedLayer.task("attn_rms_finalize",
TransformerComputeKernelsLayered::reductionFinalNormalization,
context, state.temp, config.dim(), config.rmsNormEps());
}

unifiedLayer.task("attn_rms_apply_fp16",
TransformerComputeKernels::mapContextWithQuantize,
context, state.wrapXbFP16, state.wrapX,
weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);

// QKV Projection (fused)
unifiedLayer.task("qkv_projection",
TransformerComputeKernelsLayered::fusedQKVMatmulX,
Expand Down
Expand Up @@ -161,21 +161,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
// === Attention Block ===
// RMS Normalization
unifiedLayer.task("attn_rms_reduce",
TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
context, state.temp, state.wrapX,
TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse,
context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp,
config.dim(), config.rmsNormEps(), state.localSize);

if (shouldUseFinalNormalization()) {
unifiedLayer.task("attn_rms_finalize",
TransformerComputeKernelsLayered::reductionFinalNormalization,
context, state.temp, config.dim(), config.rmsNormEps());
}

unifiedLayer.task("attn_rms_apply",
TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
context, state.wrapXb, state.wrapX,
weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);

// QKV Projection (fused with Q8 dequantization)
unifiedLayer.task("qkv_projection",
TransformerComputeKernelsLayered::fusedQKVMatmulQ8,
Expand Down Expand Up @@ -306,7 +295,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
// --- Attention Block ---
// RMS Normalization
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
Expand Down