Skip to content

Commit bdf8dc2

Browse files
authored
[WebNN EP] Support local attention feature for GQA (#26565)
### Description

<!-- Describe your changes. -->

Support the `local_window_size` attribute of the **GroupQueryAttention** operator. The attribute is designed for sliding window attention and may change the attention mask pattern.

When `local_window_size` is not equal to -1, a new attention mask pattern is created as follows to apply the sliding window:

```
condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
      |                             |
      |                           Lesser <--- local_window_size
      |                             |
  LogicalAnd <----------------- condition_2
      |
 new attn_mask
```

### Motivation and Context

<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent ff0715d commit bdf8dc2

File tree

1 file changed

+48
-12
lines changed

1 file changed

+48
-12
lines changed

onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
107107
ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape");
108108

109109
NodeAttrHelper helper(node);
110+
const int32_t local_window_size = helper.Get("local_window_size", -1);
110111
const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0);
111112
const uint32_t num_heads = helper.Get("num_heads", 0);
112113

@@ -290,18 +291,17 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
290291
| |
291292
+-------------------------------> Lesser <---------------------Transpose (1,0)
292293
|
293-
1 ---> Where <--- finfo_min (minimum value of FP32)
294+
1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32)
294295
|
295296
attention_bias
296297
*/
297-
const std::vector<int32_t> mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length,
298-
1);
299-
std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) +
300-
"_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) +
301-
"_" + std::to_string(past_sequence_length);
302-
emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant<int32_t>(
303-
ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape,
304-
std::vector<uint32_t>({batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
298+
emscripten::val value_int_one_constant =
299+
model_builder.CreateOrGetConstant<int>(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1});
300+
301+
std::vector<uint32_t> mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length};
302+
common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand");
303+
emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call<emscripten::val>(
304+
"expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options);
305305

306306
emscripten::val cumsum_options = emscripten::val::object();
307307
cumsum_options.set("label", node.Name() + "_range_of_mask_shape");
@@ -315,7 +315,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
315315
std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1);
316316

317317
std::string pre_neq_right_data_range_name =
318-
"webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length);
318+
"webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length);
319319
emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant<int32_t>(
320320
ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range,
321321
std::vector<uint32_t>({qkv_sequence_length}));
@@ -333,10 +333,46 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
333333
emscripten::val neq_right =
334334
model_builder.GetBuilder().call<emscripten::val>("transpose", expanded_neq_right, transpose_options);
335335

336-
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition");
337-
emscripten::val condition =
336+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1");
337+
emscripten::val condition_1 =
338338
model_builder.GetBuilder().call<emscripten::val>("lesser", neq_left, neq_right, common_options);
339339

340+
emscripten::val condition = condition_1;
341+
// When local_window_size is not -1, build a new attention mask pattern that applies the sliding window.
342+
/*
343+
condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
344+
| |
345+
| Lesser <--- local_window_size
346+
| |
347+
LogicalAnd <----------------- condition_2
348+
|
349+
new attn_mask
350+
*/
351+
if (local_window_size != -1) {
352+
// Cast condition
353+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast");
354+
emscripten::val casted_condition_1 =
355+
model_builder.GetBuilder().call<emscripten::val>("cast", condition_1, emscripten::val("int32"), common_options);
356+
357+
cumsum_options = emscripten::val::object();
358+
cumsum_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cumsum");
359+
cumsum_options.set("exclusive", true);
360+
cumsum_options.set("reversed", true);
361+
emscripten::val neq_left_2 = model_builder.GetBuilder().call<emscripten::val>(
362+
"cumulativeSum", casted_condition_1, gsl::narrow<uint32_t>(3), cumsum_options);
363+
364+
emscripten::val local_window_size_constant =
365+
model_builder.CreateOrGetConstant<int>(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1});
366+
367+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2");
368+
emscripten::val condition_2 =
369+
model_builder.GetBuilder().call<emscripten::val>("lesser", neq_left_2, local_window_size_constant, common_options);
370+
371+
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and");
372+
condition = model_builder.GetBuilder().call<emscripten::val>(
373+
"logicalAnd", condition_1, condition_2, common_options);
374+
}
375+
340376
emscripten::val value_one_constant =
341377
model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1});
342378

0 commit comments

Comments
 (0)