Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink,
const Tensor* seqlen_k, int local_window_size) {
Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length =
Expand All @@ -534,6 +534,11 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
parameters, past_sequence_length, total_sequence_length, seqlen_k));

if (output_qk != nullptr) {
// Copy the attention scores (scaled Q*K^T) to output_qk
Copy link

Copilot AI Nov 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is incomplete. According to lines 172-175 of the ComputeAttentionProbs function, the probs tensor contains sum * uniforms.alpha + attention_bias[outputIdx] when attention_bias is present, not just scaled QK^T. The comment should be updated to reflect this: 'Copy the attention scores (scaled QK^T with attention_bias if present) to output_qk' or 'Copy the raw attention scores (before softmax) to output_qk'.

Suggested change
// Copy the attention scores (scaled Q*K^T) to output_qk
// Copy the raw attention scores (scaled Q*K^T plus attention_bias if present, before softmax) to output_qk

Copilot uses AI. Check for mistakes.
ORT_RETURN_IF_ERROR(context.CopyTensor(probs, *output_qk));
}

ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink, local_window_size));

Expand Down Expand Up @@ -730,7 +735,7 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)

// Apply the actual attention computation
return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
/* present_value */ nullptr, parameters, context, nullptr, nullptr, -1);
/* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
}

} // namespace webgpu
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr, int local_window_size = -1);

} // namespace webgpu
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
present_value, nullptr, parameters, context, head_sink, seqlen_k, local_window_size_);
}

TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
Expand All @@ -338,7 +338,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
parameters.v_head_size_, value, nullptr, 0, &V));
return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
present_value, nullptr, parameters, context, head_sink, seqlen_k, local_window_size_);
}

KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture) {
Expand Down
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,17 @@
Tensor* present_key = context.Output(1, present_shape);
Tensor* present_value = context.Output(2, present_shape);

if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
std::vector<int64_t> output_qk_dims{

Check warning on line 97 in onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc:97: Add #include <vector> for vector<> [build/include_what_you_use] [4]
parameters.batch_size_,
parameters.num_heads_,
parameters.sequence_length_,
parameters.total_sequence_length_,
};
TensorShape output_qk_shape(output_qk_dims);
Tensor* output_qk = context.Output(3, output_qk_shape);

if (output_qk == nullptr && // Flash attention does not output QK scores
Copy link

Copilot AI Nov 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The extra space between `nullptr` and `&&` should be removed for consistent formatting.

Suggested change
if (output_qk == nullptr && // Flash attention does not output QK scores
if (output_qk == nullptr && // Flash attention does not output QK scores

Copilot uses AI. Check for mistakes.
CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
}
Expand All @@ -108,7 +118,7 @@

if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context);
present_value, output_qk, parameters, context);
}

TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_,
Expand All @@ -127,7 +137,7 @@

// Compute the attention score and apply the score to V
return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context);
present_value, output_qk, parameters, context);
}

} // namespace webgpu
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/webgpu/compute_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

namespace onnxruntime {
namespace webgpu {
ComputeContext::ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep, WebGpuContext& webgpu_context)
ComputeContext::ComputeContext(OpKernelContext& kernel_context,
const OpKernel& op_kernel,
const WebGpuExecutionProvider& ep,
WebGpuContext& webgpu_context)
: webgpu_context_{webgpu_context},
kernel_context_{kernel_context},
op_kernel_{op_kernel},
ep_{ep} {
}

Expand Down
16 changes: 15 additions & 1 deletion onnxruntime/core/providers/webgpu/compute_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <utility>

#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_provider.h"
#include "core/providers/webgpu/webgpu_execution_provider.h"

Expand Down Expand Up @@ -36,7 +37,10 @@ class ComputeContext final {
static const webgpu::BufferManager& Get(const ComputeContext& context);
};

ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep, WebGpuContext& webgpu_context);
ComputeContext(OpKernelContext& kernel_context,
const OpKernel& op_kernel,
const WebGpuExecutionProvider& ep,
WebGpuContext& webgpu_context);

~ComputeContext() = default;

Expand Down Expand Up @@ -132,6 +136,15 @@ class ComputeContext final {
return {data_type, std::forward<TensorShapeType>(shape), allocator};
}

//
// Copy data from a tensor to another tensor.
//
// Delegates to the kernel's DataTransferManager, which picks the
// appropriate device/host transfer path for the two tensors.
//
// This method assumes that both tensors have the same data size;
// no size or shape validation is performed here.
//
inline Status CopyTensor(const Tensor& src, Tensor& dst) {
  return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst);
}

//
// Run a compute shader program.
//
Expand All @@ -142,6 +155,7 @@ class ComputeContext final {
private:
WebGpuContext& webgpu_context_;
OpKernelContext& kernel_context_;
const OpKernel& op_kernel_;
const WebGpuExecutionProvider& ep_;
};

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/webgpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ WebGpuKernel::WebGpuKernel(const OpKernelInfo& info)

Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const {
WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId());
ComputeContext context{*p_op_kernel_context, ep_, webgpu_context};
ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context};

if (webgpu_context.ValidationMode() >= ValidationMode::Full) {
webgpu_context.PushErrorScope();
Expand Down
Loading