Skip to content

Commit c32777f

Browse files
committed
refactor
1 parent d1f03b5 commit c32777f

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,22 +297,20 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
297297
// query points to packed QKV, K and V are nullptr since they're not needed
298298
return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value,
299299
present_value, parameters, context, seqlen_k, cos_cache, sin_cache);
300-
} else {
301-
// Fused: splitQKV + rotary QK
302-
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
303-
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
304-
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
305-
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbedding(context, parameters,
306-
query, seqlen_k,
307-
cos_cache, sin_cache,
308-
&qSplit, &kSplit, &vSplit));
309-
key = &kSplit;
310-
value = &vSplit;
311300
}
312-
301+
// Fused: splitQKV + rotary QK
302+
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
303+
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
304+
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
305+
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbedding(context, parameters,
306+
query, seqlen_k,
307+
cos_cache, sin_cache,
308+
&qSplit, &kSplit, &vSplit));
313309
parameters.is_packed_qkv_ = false;
314310
parameters.qkv_format_ = Q_K_V_BSNH;
315311
query = &qSplit;
312+
key = &kSplit;
313+
value = &vSplit;
316314
} else {
317315
if (parameters.is_packed_qkv_) {
318316
// splitQKV

0 commit comments

Comments
 (0)