
Commit 91a9d02

fs-eire and Copilot authored
[webgpu] add support of output_qk for MHA (#26553)
### Description

The WebGPU EP does not support MHA's `qk` output yet. This PR makes it handle the `qk` output correctly.

---------

Co-authored-by: Copilot <[email protected]>
1 parent 8ac5670 commit 91a9d02
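For context, the optional `qk` output exposes the attention scores before softmax: in the changes below, `probs` is copied out right after `ComputeAttentionProbs` and before the in-place softmax. Under the usual MHA convention (an assumption; the scale factor is not shown in this diff), the copied tensor is

$$\mathrm{output\_qk} = \frac{Q K^{\top}}{\sqrt{d_h}} \;\bigl(+\; \mathrm{attention\_bias}\ \text{when present}\bigr),$$

with shape `(batch_size, num_heads, sequence_length, total_sequence_length)` as set in multihead_attention.cc below.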

7 files changed: +45 −12 lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 8 additions & 3 deletions
@@ -520,8 +520,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
 
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink,
-                      const Tensor* seqlen_k, int local_window_size) {
+                      Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
+                      const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
   const int total_sequence_length =

@@ -534,6 +534,11 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
   ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
                                             parameters, past_sequence_length, total_sequence_length, seqlen_k));
 
+  if (output_qk != nullptr) {
+    // Copy the attention scores (scaled Q*K^T) to output_qk
+    ORT_RETURN_IF_ERROR(context.CopyTensor(probs, *output_qk));
+  }
+
   ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
                                             parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink, local_window_size));
 

@@ -730,7 +735,7 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
 
   // Apply the actual attention computation
   return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
-                        /* present_value */ nullptr, parameters, context, nullptr, nullptr, -1);
+                        /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
 }
 
 }  // namespace webgpu
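Since the copy above runs between `ComputeAttentionProbs` and `ComputeInPlaceSoftmax`, `output_qk` receives the raw scaled scores, and the same `probs` buffer is then overwritten with probabilities. A standalone CPU sketch of that ordering (illustration only, not ORT code; the 1/sqrt(head_size) scale is an assumption):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int seq = 2, total_seq = 2, head = 2;
  std::vector<float> Q = {1, 0, 0, 1};  // [seq, head]
  std::vector<float> K = {1, 0, 0, 1};  // [total_seq, head]
  const float scale = 1.0f / std::sqrt(static_cast<float>(head));

  // probs = scale * Q * K^T, as after ComputeAttentionProbs
  std::vector<float> probs(seq * total_seq);
  for (int i = 0; i < seq; ++i)
    for (int j = 0; j < total_seq; ++j) {
      float dot = 0;
      for (int k = 0; k < head; ++k) dot += Q[i * head + k] * K[j * head + k];
      probs[i * total_seq + j] = scale * dot;
    }

  std::vector<float> output_qk = probs;  // snapshot: pre-softmax scores

  for (int i = 0; i < seq; ++i) {  // in-place softmax per row, as in ComputeInPlaceSoftmax
    float sum = 0;
    for (int j = 0; j < total_seq; ++j) sum += std::exp(probs[i * total_seq + j]);
    for (int j = 0; j < total_seq; ++j) probs[i * total_seq + j] = std::exp(probs[i * total_seq + j]) / sum;
  }

  // output_qk kept the raw score; probs now holds a probability.
  std::printf("qk[0][0]=%.4f prob[0][0]=%.4f\n", output_qk[0], probs[0]);
  return 0;
}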

onnxruntime/contrib_ops/webgpu/bert/attention_common.h

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
 
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
+                      Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
                       const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr, int local_window_size = -1);
 
 }  // namespace webgpu

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 2 additions & 2 deletions
@@ -321,7 +321,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
       context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
     return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
-                          present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
+                          present_value, nullptr, parameters, context, head_sink, seqlen_k, local_window_size_);
   }
 
   TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_,

@@ -338,7 +338,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
                                         parameters.v_head_size_, value, nullptr, 0, &V));
   return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
-                        present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
+                        present_value, nullptr, parameters, context, head_sink, seqlen_k, local_window_size_);
 }
 
 KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture) {

onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc

Lines changed: 13 additions & 3 deletions
@@ -94,7 +94,17 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_key = context.Output(1, present_shape);
   Tensor* present_value = context.Output(2, present_shape);
 
-  if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
+  std::vector<int64_t> output_qk_dims{
+      parameters.batch_size_,
+      parameters.num_heads_,
+      parameters.sequence_length_,
+      parameters.total_sequence_length_,
+  };
+  TensorShape output_qk_shape(output_qk_dims);
+  Tensor* output_qk = context.Output(3, output_qk_shape);
+
+  if (output_qk == nullptr &&  // Flash attention does not output QK scores
+      CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
                                present_value, parameters, context);
   }

@@ -108,7 +118,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
 
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
     return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
-                          present_value, parameters, context);
+                          present_value, output_qk, parameters, context);
   }
 
   TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_,

@@ -127,7 +137,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
 
   // Compute the attention score and apply the score to V
   return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
-                        present_value, parameters, context);
+                        present_value, output_qk, parameters, context);
 }
 
 }  // namespace webgpu
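The `output_qk == nullptr` guard exists because flash-style kernels fuse softmax into the tile loop (online softmax) and never materialize the full `sequence_length × total_sequence_length` score matrix that `qk` requires; requesting `qk` therefore falls through to the regular path. A minimal CPU sketch of online softmax over one attention row (illustration only, not the EP's shader):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One attention row; in a fused kernel these scores are produced and
  // consumed tile by tile, never stored as a full matrix.
  std::vector<float> scores = {0.5f, 2.0f, -1.0f, 0.25f};

  // Online softmax keeps only a running max and a running denominator.
  float running_max = -INFINITY, denom = 0.0f;
  for (float s : scores) {
    float new_max = std::max(running_max, s);
    denom = denom * std::exp(running_max - new_max) + std::exp(s - new_max);
    running_max = new_max;
  }

  // Probabilities are recoverable per element, but the pre-softmax row was
  // never retained -- which is exactly what a qk snapshot would need.
  for (float s : scores) std::printf("%.4f ", std::exp(s - running_max) / denom);
  std::printf("\n");
  return 0;
}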

onnxruntime/core/providers/webgpu/compute_context.cc

Lines changed: 5 additions & 1 deletion
@@ -6,9 +6,13 @@
 
 namespace onnxruntime {
 namespace webgpu {
-ComputeContext::ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep, WebGpuContext& webgpu_context)
+ComputeContext::ComputeContext(OpKernelContext& kernel_context,
+                               const OpKernel& op_kernel,
+                               const WebGpuExecutionProvider& ep,
+                               WebGpuContext& webgpu_context)
     : webgpu_context_{webgpu_context},
       kernel_context_{kernel_context},
+      op_kernel_{op_kernel},
       ep_{ep} {
 }
 

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 15 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 #include <utility>
 
+#include "core/framework/data_transfer_manager.h"
 #include "core/framework/execution_provider.h"
 #include "core/providers/webgpu/webgpu_execution_provider.h"
 

@@ -36,7 +37,10 @@ class ComputeContext final {
     static const webgpu::BufferManager& Get(const ComputeContext& context);
   };
 
-  ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep, WebGpuContext& webgpu_context);
+  ComputeContext(OpKernelContext& kernel_context,
+                 const OpKernel& op_kernel,
+                 const WebGpuExecutionProvider& ep,
+                 WebGpuContext& webgpu_context);
 
   ~ComputeContext() = default;
 

@@ -132,6 +136,15 @@ class ComputeContext final {
     return {data_type, std::forward<TensorShapeType>(shape), allocator};
   }
 
+  //
+  // Copy data from a tensor to another tensor.
+  //
+  // This method assumes that both tensors have the same data size.
+  //
+  inline Status CopyTensor(const Tensor& src, Tensor& dst) {
+    return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst);
+  }
+
   //
   // Run a compute shader program.
   //

@@ -142,6 +155,7 @@ class ComputeContext final {
  private:
  WebGpuContext& webgpu_context_;
  OpKernelContext& kernel_context_;
+  const OpKernel& op_kernel_;
   const WebGpuExecutionProvider& ep_;
 };
 
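A hedged usage sketch of the new helper from kernel code. `CopyTensor` and `Output` come from this diff; the kernel, the `Input` accessor, and the tensor roles are assumptions for illustration:

// Hypothetical WebGPU kernel body (sketch):
Status MyKernel::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* src = context.Input<Tensor>(0);    // assumed accessor
  Tensor* dst = context.Output(0, src->Shape());   // same shape, so same data size
  if (dst != nullptr) {
    // Routed through the kernel's DataTransferManager, so GPU-resident
    // tensors are handled; both tensors must have the same data size.
    ORT_RETURN_IF_ERROR(context.CopyTensor(*src, *dst));
  }
  return Status::OK();
}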

onnxruntime/core/providers/webgpu/webgpu_kernel.cc

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ WebGpuKernel::WebGpuKernel(const OpKernelInfo& info)
 
 Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const {
   WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId());
-  ComputeContext context{*p_op_kernel_context, ep_, webgpu_context};
+  ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context};
 
   if (webgpu_context.ValidationMode() >= ValidationMode::Full) {
     webgpu_context.PushErrorScope();
