# [webgpu] add support of output_qk for MHA #26553
base: main
Conversation
Force-pushed from `b6e6d54` to `7d0139a`.
Pull Request Overview
This pull request adds support for the output_qk parameter in the MultiHeadAttention operator for the WebGPU execution provider. The output_qk output contains the scaled Q*K^T attention scores (before softmax), which is useful for debugging and visualization purposes.
Key Changes
- Modified `ComputeContext` to add a `CopyTensor` method for copying tensor data, requiring an `OpKernel` reference in the constructor
- Updated the `ApplyAttention` function signature to accept an optional `output_qk` parameter
- Modified `MultiHeadAttention` to compute and output QK scores when requested, with flash attention disabled when QK output is needed
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `compute_context.h` | Added `CopyTensor` method and updated constructor to accept an `OpKernel` reference |
| `compute_context.cc` | Updated constructor implementation to match header changes |
| `multihead_attention.cc` | Added logic to create the `output_qk` tensor and conditionally disable flash attention when QK output is requested |
| `group_query_attention.cc` | Passed `nullptr` for the `output_qk` parameter as GQA doesn't support this feature |
| `attention_common.h` | Updated `ApplyAttention` function signature to include the `output_qk` parameter |
| `attention.cc` | Implemented QK score copying logic and updated function signatures |
### Description
Fixed compilation error in `webgpu_kernel.cc` where `ComputeContext`
constructor call was missing the required `OpKernel` parameter.
**Change:**
```cpp
// Before (compilation error - 3 params, needs 4)
ComputeContext context{*p_op_kernel_context, ep_, webgpu_context};
// After (correct)
ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context};
```
The constructor signature requires the `OpKernel` reference as the
second parameter. Since `WebGpuKernel` inherits from `OpKernel`, passing
`*this` satisfies the requirement. The `op_kernel_` member is used
internally by `CopyTensor()` to access the data transfer manager.
### Motivation and Context
Addresses review feedback from
#26553 (comment)
which identified the missing parameter that would cause a compilation
failure.
---
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: fs-eire <[email protected]>
Pull Request Overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
```cpp
TensorShape output_qk_shape(output_qk_dims);
Tensor* output_qk = context.Output(3, output_qk_shape);
```

```cpp
if (output_qk == nullptr &&  // Flash attention does not output QK scores
```

**Copilot AI** reviewed on Nov 15, 2025:

> [nitpick] Extra space between `nullptr` and `&&` should be removed for consistent formatting.
```cpp
parameters, past_sequence_length, total_sequence_length, seqlen_k));
```

```cpp
if (output_qk != nullptr) {
  // Copy the attention scores (scaled Q*K^T) to output_qk
```

**Copilot AI** reviewed on Nov 15, 2025:

> The comment is incomplete. According to lines 172-175 of the `ComputeAttentionProbs` function, the `probs` tensor contains `sum * uniforms.alpha + attention_bias[outputIdx]` when `attention_bias` is present, not just scaled Q*K^T. The comment should be updated to reflect this, for example:

```cpp
// Copy the raw attention scores (scaled Q*K^T plus attention_bias if present, before softmax) to output_qk
```
### Description
The WebGPU EP does not support MHA's `qk` output yet. This PR makes it handle the `qk` output correctly.