Skip to content

Commit d6e56cd

Browse files
committed
[webgpu] Throw errors for graph catpure when not implemented
1 parent 91a9d02 commit d6e56cd

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,9 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
522522
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
523523
Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
524524
const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) {
525+
if (context.IsGraphCaptureEnabled()) {
526+
ORT_NOT_IMPLEMENTED("Graph capture not implemented for non flash attention path");
527+
}
525528
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
526529
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
527530
const int total_sequence_length =

0 commit comments

Comments
 (0)