@@ -16,6 +16,32 @@ namespace onnxruntime {
 namespace contrib {
 namespace webgpu {
 
+Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
+  const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
+  const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
+  const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform);
+  const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform);
+
+  const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform);
+  const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform);
+  const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform);
+
+  if (prepare_indirect_dispatch_) {
+    sh.AddOutput("indirect_buffer", ShaderUsage::None);
+  }
+
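+  // WGSL_TEMPLATE_PARAMETER values are baked into the generated WGSL when the
+  // template is expanded, while WGSL_TEMPLATE_VARIABLE entries bind the tensor
+  // views declared above to the names used inside the template.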
+  return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
+                             WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
+                             WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
+                             WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
+                             WGSL_TEMPLATE_VARIABLE(present_key, present_key),
+                             WGSL_TEMPLATE_VARIABLE(present_value, present_value),
+                             WGSL_TEMPLATE_VARIABLE(query, query),
+                             WGSL_TEMPLATE_VARIABLE(seqlens, seqlens),
+                             WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache));
+}
+
 Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   // Expectations are
   //   qkv have same number of heads and hidden dimension (head size).
@@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
+                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
+                           const Tensor* cos_cache, const Tensor* sin_cache) {
+  constexpr uint32_t tile_size = 64;
+
   // Extract present_sequence_length directly from present_key tensor shape:
   // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
   const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
 
   const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
 
+  // Declare query_output at function scope to ensure it persists throughout the function
+  Tensor query_output;
+
+  // Create indirect dispatch buffer if using indirect dispatch
+  Tensor* indirect_buffer_ptr = nullptr;
+  Tensor indirect_buffer;
+
+  // Prepare indirect dispatch buffer for decode path with static KV cache
+  const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
+                                     parameters.past_present_share_buffer_ &&
+                                     seqlen_k != nullptr &&
+                                     context.IsGraphCaptureEnabled();
+  if (use_indirect_dispatch) {
+    const TensorShape indirect_buffer_shape{3};  // 3 uint32 values for dispatch dimensions
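+    // The three u32s hold the (x, y, z) workgroup counts that the GPU reads
+    // when the decode kernels are later launched via indirect dispatch.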
+    indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
+    indirect_buffer_ptr = &indirect_buffer;
+  }
+
+  const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr);
+
+  if (do_rotary) {
+    ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input.");
+    ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache.");
+
+    // Q points to the packed QKV tensor in this case; create the query output tensor.
+    query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
+
+    ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters,
+                                                                      Q, seqlen_k,
+                                                                      cos_cache, sin_cache,
+                                                                      &query_output, present_key, present_value,
+                                                                      indirect_buffer_ptr, tile_size));
+    Q = &query_output;
+  } else {
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr));
+  }
+
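+  // At this point both paths have produced an unpacked Q and a populated KV
+  // cache, so the attention kernels below are independent of which path ran.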
   if (parameters.sequence_length_ > 1) {
-    const uint32_t tile_size = 64;
-    // For encode path, use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
     bool has_attention_bias = attention_bias != nullptr;
     bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
     bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                    parameters.sequence_length_, present_sequence_length});
   const TensorShape qk_shape(qk_dims);
   Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
-  constexpr uint32_t tile_size = 64;
   const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
   const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
 
-  // Determine if we should use indirect dispatch
-  const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
-                                     seqlen_k != nullptr &&
-                                     context.IsGraphCaptureEnabled();
-
-  // Create indirect dispatch buffer if using indirect dispatch
-  Tensor* indirect_buffer_ptr = nullptr;
-  Tensor indirect_buffer;
-  if (use_indirect_dispatch) {
-    const TensorShape indirect_buffer_shape{3};  // 3 uint32 values for dispatch dimensions
-    indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
-    indirect_buffer_ptr = &indirect_buffer;
-    // Use the fused CopyKVCache that also prepares the indirect dispatch buffer
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
-  } else {
-    // Use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
-  }
-
   // The metadata is used to store the max and sum of each tile.
   const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
                                          num_present_sequence_length_tile, 2});
@@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
          ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
 }
 
+Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
+                                                     const WebgpuAttentionParameters& params,
+                                                     const Tensor* packedQKV,
+                                                     const Tensor* seqlen_k,
+                                                     const Tensor* cos_cache,
+                                                     const Tensor* sin_cache,
+                                                     Tensor* query,
+                                                     Tensor* present_key,
+                                                     Tensor* present_value,
+                                                     Tensor* indirect_buffer,
+                                                     uint32_t tile_size) {
+  const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
+  const auto head_size = params.head_size_;
+
+  int components = 1;
+  // Currently we only support vectorization when RoPE is not interleaved.
+  if (!params.rotary_interleaved_) {
+    if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) {
+      components = 4;
+    } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) {
+      components = 2;
+    }
+  }
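+  // Illustrative example: with head_size = 128 and half_rotary_embedding_dim = 32,
+  // both are divisible by 4, so components = 4 and the shader works on vec4 elements.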
+  // Adjust dimensions for vectorization.
+  const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components;
+  const auto head_size_vec = head_size / components;
+
+  // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim)
+  //   work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim)
+  //                 = head_size - half_rotary_dim
+  const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec;
+  auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head);
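+  // Continuing the illustrative values above (head_size_vec = 32, half_rotary_embedding_dim_vec = 8):
+  // 8 threads each rotate one element pair (covering 16 elements) and 16 threads copy the
+  // non-rotary tail, so work_per_head = 8 + 16 = 32 - 8 = 24.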
+
+  // Extract present_sequence_length from present_key tensor shape
+  const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);
+
+  const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
+
+  SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
+  program
+      .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)
+      .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
+      .AddInputs({
+          {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
+          {cos_cache, ProgramTensorMetadataDependency::Rank, components},
+          {sin_cache, ProgramTensorMetadataDependency::Rank, components},
+      });
+  program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components},
+                      {present_key, ProgramTensorMetadataDependency::None, components},
+                      {present_value, ProgramTensorMetadataDependency::None, components}});
+
+  if (prepare_indirect_dispatch) {
+    program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
+  }
+
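+  // Uniforms are matched positionally, so the order below is expected to follow
+  // the uniform variable declarations of this program class.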
+  program.AddUniformVariables({
+      {static_cast<uint32_t>(params.sequence_length_)},
+      {static_cast<uint32_t>(params.hidden_size_ / components)},
+      {static_cast<uint32_t>(params.kv_hidden_size_ / components)},
+      {static_cast<uint32_t>(params.num_heads_)},
+      {static_cast<uint32_t>(params.kv_num_heads_)},
+      {head_size_vec},
+      {half_rotary_embedding_dim_vec},
+      {present_sequence_length},
+      {tile_size},
+      {static_cast<uint32_t>(dispatch_size)},
+  });
+
+  program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
+  return context.RunProgram(program);
+}
+
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime