@@ -297,22 +297,20 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
297297 // query points to packed QKV, K and V are nullptr since they're not needed
298298 return ApplyFlashAttention (query, nullptr , nullptr , attention_bias, output, past_key, present_key, past_value,
299299 present_value, parameters, context, seqlen_k, cos_cache, sin_cache);
300- } else {
301- // Fused: splitQKV + rotary QK
302- qSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.hidden_size_ }));
303- kSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
304- vSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
305- ORT_RETURN_IF_ERROR (RunSplitPackedQKVWithRotaryEmbedding (context, parameters,
306- query, seqlen_k,
307- cos_cache, sin_cache,
308- &qSplit, &kSplit , &vSplit));
309- key = &kSplit ;
310- value = &vSplit;
311300 }
312-
301+ // Fused: splitQKV + rotary QK
302+ qSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.hidden_size_ }));
303+ kSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
304+ vSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
305+ ORT_RETURN_IF_ERROR (RunSplitPackedQKVWithRotaryEmbedding (context, parameters,
306+ query, seqlen_k,
307+ cos_cache, sin_cache,
308+ &qSplit, &kSplit , &vSplit));
313309 parameters.is_packed_qkv_ = false ;
314310 parameters.qkv_format_ = Q_K_V_BSNH;
315311 query = &qSplit;
312+ key = &kSplit ;
313+ value = &vSplit;
316314 } else {
317315 if (parameters.is_packed_qkv_ ) {
318316 // splitQKV
0 commit comments