Skip to content

Commit 96926a0

Browse files
authored
[webgpu] Fused CopyKVCache and SplitPackedQKVWithRotaryEmbedding as SplitPackedQKVWithRotaryEmbeddingAndCopyKV (#26563)
### Description <!-- Describe your changes. --> Create an ultimate fused path called SplitPackedQKVWithRotaryEmbeddingAndCopyKV, which fuses SplitPackedQKVWithRotaryEmbedding and CopyKVCache. It runs when flash attention is used and the static KV cache is enabled. We did the following: - Added component (vectorization) support to the existing SplitPackedQKVWithRotaryEmbedding - Fused it with CopyKVCache into the new SplitPackedQKVWithRotaryEmbeddingAndCopyKV ### Motivation and Context On NV5080, the token generation speed improves by ~4%. | generation tps | Before | After | |--------|--------|-------| | NV5080 | 135 | **141** | | Intel | 15.3 | 15.4 | | Mac | 71.2 | 71.8 |
1 parent bdf8dc2 commit 96926a0

File tree

5 files changed

+349
-53
lines changed

5 files changed

+349
-53
lines changed

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 139 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,32 @@ namespace onnxruntime {
1616
namespace contrib {
1717
namespace webgpu {
1818

19+
Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
20+
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
21+
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
22+
const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform);
23+
const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform);
24+
25+
const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform);
26+
const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform);
27+
const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform);
28+
29+
if (prepare_indirect_dispatch_) {
30+
sh.AddOutput("indirect_buffer", ShaderUsage::None);
31+
}
32+
33+
return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
34+
WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
35+
WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
36+
WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
37+
WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
38+
WGSL_TEMPLATE_VARIABLE(present_key, present_key),
39+
WGSL_TEMPLATE_VARIABLE(present_value, present_value),
40+
WGSL_TEMPLATE_VARIABLE(query, query),
41+
WGSL_TEMPLATE_VARIABLE(seqlens, seqlens),
42+
WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache));
43+
}
44+
1945
Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
2046
// Expectations are
2147
// qkv have same number of heads and hidden dimension (head size).
@@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
351377

352378
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
353379
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
354-
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
380+
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
381+
const Tensor* cos_cache, const Tensor* sin_cache) {
382+
constexpr uint32_t tile_size = 64;
383+
355384
// Extract present_sequence_length directly from present_key tensor shape:
356385
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
357386
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
358387

359388
const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
360389

390+
// Declare query_output at function scope to ensure it persists throughout the function
391+
Tensor query_output;
392+
393+
// Create indirect dispatch buffer if using indirect dispatch
394+
Tensor* indirect_buffer_ptr = nullptr;
395+
Tensor indirect_buffer;
396+
397+
// Prepare indirect dispatch buffer for decode path with static KV cache
398+
const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
399+
parameters.past_present_share_buffer_ &&
400+
seqlen_k != nullptr &&
401+
context.IsGraphCaptureEnabled();
402+
if (use_indirect_dispatch) {
403+
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
404+
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
405+
indirect_buffer_ptr = &indirect_buffer;
406+
}
407+
408+
const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr);
409+
410+
if (do_rotary) {
411+
ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input.");
412+
ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache.");
413+
414+
// Q points to the packed QKV tensor in this case, create query output tensor
415+
query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
416+
417+
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters,
418+
Q, seqlen_k,
419+
cos_cache, sin_cache,
420+
&query_output, present_key, present_value,
421+
indirect_buffer_ptr, tile_size));
422+
Q = &query_output;
423+
} else {
424+
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr));
425+
}
426+
361427
if (parameters.sequence_length_ > 1) {
362-
const uint32_t tile_size = 64;
363-
// For encode path, use the original CopyKVCache without indirect dispatch preparation
364-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
365428
bool has_attention_bias = attention_bias != nullptr;
366429
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
367430
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
406469
parameters.sequence_length_, present_sequence_length});
407470
const TensorShape qk_shape(qk_dims);
408471
Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
409-
constexpr uint32_t tile_size = 64;
410472
const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
411473
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
412474

413-
// Determine if we should use indirect dispatch
414-
const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
415-
seqlen_k != nullptr &&
416-
context.IsGraphCaptureEnabled();
417-
418-
// Create indirect dispatch buffer if using indirect dispatch
419-
Tensor* indirect_buffer_ptr = nullptr;
420-
Tensor indirect_buffer;
421-
if (use_indirect_dispatch) {
422-
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
423-
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
424-
indirect_buffer_ptr = &indirect_buffer;
425-
// Use the fused CopyKVCache that also prepares the indirect dispatch buffer
426-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
427-
} else {
428-
// Use the original CopyKVCache without indirect dispatch preparation
429-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
430-
}
431-
432475
// The metadata is used to store the max and sum of each tile.
433476
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
434477
num_present_sequence_length_tile, 2});
@@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
467510
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
468511
}
469512

513+
// Launches the fused program that splits packed QKV, applies rotary embedding
// to Q/K, and copies K/V into the present (static) KV cache in one pass.
// When `indirect_buffer` is non-null the shader additionally writes the
// indirect-dispatch arguments used by the subsequent decode kernels.
//
// packedQKV: packed QKV activations (Q points here on the caller side).
// seqlen_k:  per-batch past sequence lengths.
// cos_cache / sin_cache: RoPE tables; cos_cache->Shape()[1] is taken as the
//   half rotary dimension (assumes shape (max_seq, rotary_dim/2) — matches
//   the GQA cos/sin cache convention; confirm against the op's schema).
// query / present_key / present_value: outputs.
// tile_size: tile size the indirect dispatch is computed for.
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
                                                     const WebgpuAttentionParameters& params,
                                                     const Tensor* packedQKV,
                                                     const Tensor* seqlen_k,
                                                     const Tensor* cos_cache,
                                                     const Tensor* sin_cache,
                                                     Tensor* query,
                                                     Tensor* present_key,
                                                     Tensor* present_value,
                                                     Tensor* indirect_buffer,
                                                     uint32_t tile_size) {
  const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
  const auto head_size = params.head_size_;

  // Pick the widest vector width (4 > 2 > 1) that divides both the head size
  // and the half rotary dimension, so all shader loads/stores stay aligned.
  int components = 1;
  // Currently we only support vectorization when RoPE is not interleaved
  if (!params.rotary_interleaved_) {
    if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) {
      components = 4;
    } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) {
      components = 2;
    }
  }
  // Adjust dimensions for vectorization
  const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components;
  const auto head_size_vec = head_size / components;

  // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim)
  // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim)
  //               = head_size - half_rotary_dim
  const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec;
  auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head);

  // Extract present_sequence_length from present_key tensor shape
  const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);

  const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);

  SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
  // Both flags alter the generated WGSL, so they must be part of the cache key.
  program
      .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)
      .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
      .AddInputs({
          {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
          {cos_cache, ProgramTensorMetadataDependency::Rank, components},
          {sin_cache, ProgramTensorMetadataDependency::Rank, components},
      });
  program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components},
                      {present_key, ProgramTensorMetadataDependency::None, components},
                      {present_value, ProgramTensorMetadataDependency::None, components}});

  if (prepare_indirect_dispatch) {
    program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
  }

  // Order must match WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES in the header.
  // hidden/kv_hidden/head sizes are passed in vectorized (per-component) units.
  program.AddUniformVariables({
      {static_cast<uint32_t>(params.sequence_length_)},
      {static_cast<uint32_t>(params.hidden_size_ / components)},
      {static_cast<uint32_t>(params.kv_hidden_size_ / components)},
      {static_cast<uint32_t>(params.num_heads_)},
      {static_cast<uint32_t>(params.kv_num_heads_)},
      {head_size_vec},
      {half_rotary_embedding_dim_vec},
      {present_sequence_length},
      {tile_size},
      {static_cast<uint32_t>(dispatch_size)},
  });

  // One invocation per unit of work; round up to whole workgroups.
  program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  return context.RunProgram(program);
}
584+
470585
} // namespace webgpu
471586
} // namespace contrib
472587
} // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,32 @@ namespace webgpu {
1515

1616
using namespace onnxruntime::webgpu;
1717

18+
// WebGPU program that fuses three steps into a single shader pass:
// splitting a packed QKV tensor, applying rotary embedding to Q/K, and
// copying K/V into the present KV cache. Optionally also fills the
// indirect-dispatch argument buffer (see prepare_indirect_dispatch).
class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram> {
 public:
  // interleaved: rotary embedding layout is interleaved (vs. half-split).
  // prepare_indirect_dispatch: additionally emit the 3 x u32 dispatch args.
  SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch)
      : Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"},
        interleaved_(interleaved),
        prepare_indirect_dispatch_(prepare_indirect_dispatch) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Declaration order defines the uniform layout; callers must push values
  // via AddUniformVariables in exactly this order. Sizes that are divided by
  // the vector component count are supplied pre-divided.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"sequence_length", ProgramUniformVariableDataType::Uint32},
      {"hidden_size", ProgramUniformVariableDataType::Uint32},
      {"kv_hidden_size", ProgramUniformVariableDataType::Uint32},
      {"num_heads", ProgramUniformVariableDataType::Uint32},
      {"kv_num_heads", ProgramUniformVariableDataType::Uint32},
      {"head_size", ProgramUniformVariableDataType::Uint32},
      {"half_rotary_dim", ProgramUniformVariableDataType::Uint32},
      {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
      {"tile_size", ProgramUniformVariableDataType::Uint32},
      {"dispatch_size", ProgramUniformVariableDataType::Uint32});

 private:
  const bool interleaved_;
  const bool prepare_indirect_dispatch_;
};
43+
1844
class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
1945
public:
2046
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
@@ -145,10 +171,24 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD
145171

146172
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
147173
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
148-
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
174+
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr,
175+
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr);
149176

150177
bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
151178
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
179+
180+
// Split packed QKV with Q/K rotary embedding and copy KV cache fusion
181+
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
182+
const WebgpuAttentionParameters& params,
183+
const Tensor* packedQKV,
184+
const Tensor* seqlen_k,
185+
const Tensor* cos_cache,
186+
const Tensor* sin_cache,
187+
Tensor* query,
188+
Tensor* present_key,
189+
Tensor* present_value,
190+
Tensor* indirect_buffer,
191+
uint32_t tile_size);
152192
} // namespace webgpu
153193
} // namespace contrib
154194
} // namespace onnxruntime

0 commit comments

Comments
 (0)