Skip to content

Commit 6af6346

Browse files
committed
temp local window size
add log info temp Support local_window_size for WebNN GQA
1 parent 423a03f commit 6af6346

File tree

1 file changed

+45
-5
lines changed

1 file changed

+45
-5
lines changed

onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
107107
ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape");
108108

109109
NodeAttrHelper helper(node);
110+
const int32_t local_window_size = helper.Get("local_window_size", -1);
110111
const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0);
111112
const uint32_t num_heads = helper.Get("num_heads", 0);
112113

@@ -290,13 +291,13 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
290291
| |
291292
+-------------------------------> Lesser <---------------------Transpose (1,0)
292293
|
293-
1 ---> Where <--- finfo_min (minimum value of FP32)
294+
1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32)
294295
|
295296
attention_bias
296297
*/
297298
const std::vector<int32_t> mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length,
298299
1);
299-
std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) +
300+
std::string mask_shape_ones_shape_name = "webnn_GQA_mask_shape_ones_" + std::to_string(batch_size) +
300301
"_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) +
301302
"_" + std::to_string(past_sequence_length);
302303
emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant<int32_t>(
@@ -315,7 +316,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
315316
std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1);
316317

317318
std::string pre_neq_right_data_range_name =
318-
"webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length);
319+
"webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length);
319320
emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant<int32_t>(
320321
ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range,
321322
std::vector<uint32_t>({qkv_sequence_length}));
@@ -333,10 +334,49 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
333334
emscripten::val neq_right =
334335
model_builder.GetBuilder().call<emscripten::val>("transpose", expanded_neq_right, transpose_options);
335336

336-
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition");
337-
emscripten::val condition =
337+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1");
338+
emscripten::val condition_1 =
338339
model_builder.GetBuilder().call<emscripten::val>("lesser", neq_left, neq_right, common_options);
339340

341+
emscripten::val condition = condition_1;
342+
// When local_window_size is not -1, build a new attention mask pattern that applies the sliding window.
343+
/*
344+
condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
345+
| |
346+
| Lesser <--- local_window_size
347+
| |
348+
LogicalAnd <----------------- condition_2
349+
|
350+
new attn_mask
351+
*/
352+
if (local_window_size != -1) {
353+
emscripten::val console = emscripten::val::global("console");
354+
console.call<void>("log", emscripten::val("local window size is not -1."));
355+
// Cast condition_1 to int32 before passing it to cumulativeSum.
356+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast");
357+
emscripten::val casted_condition_1 =
358+
model_builder.GetBuilder().call<emscripten::val>("cast", condition_1, emscripten::val("int32"), common_options);
359+
360+
cumsum_options = emscripten::val::object();
361+
cumsum_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cumsum");
362+
cumsum_options.set("exclusive", true);
363+
cumsum_options.set("reversed", true);
364+
emscripten::val neq_left_2 = model_builder.GetBuilder().call<emscripten::val>(
365+
"cumulativeSum", casted_condition_1, gsl::narrow<uint32_t>(3), cumsum_options);
366+
367+
emscripten::val local_window_size_constant =
368+
model_builder.CreateOrGetConstant<int>(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1});
369+
370+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2");
371+
emscripten::val condition_2 =
372+
model_builder.GetBuilder().call<emscripten::val>("lesser", neq_left_2, local_window_size_constant, common_options);
373+
374+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and");
375+
condition = model_builder.GetBuilder().call<emscripten::val>(
376+
"logicalAnd", condition_1, condition_2, common_options);
377+
console.call<void>("log", condition_2);
378+
}
379+
340380
emscripten::val value_one_constant =
341381
model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1});
342382

0 commit comments

Comments
 (0)