@@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
107107 ORT_RETURN_IF_NOT (GetShape (*input_defs[3 ], input_past_k_shape, logger), " Cannot get past_key shape" );
108108
109109 NodeAttrHelper helper (node);
110+ const int32_t local_window_size = helper.Get (" local_window_size" , -1 );
110111 const uint32_t kv_num_heads = helper.Get (" kv_num_heads" , 0 );
111112 const uint32_t num_heads = helper.Get (" num_heads" , 0 );
112113
@@ -290,18 +291,17 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
290291 | |
291292 +-------------------------------> Lesser <---------------------Transpose (1,0)
292293 |
293- 1 ---> Where <--- finfo_min (minimum value of FP32)
294+ 1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32)
294295 |
295296 attention_bias
296297 */
297- const std::vector<int32_t > mask_shape_ones_shape (batch_size * num_heads * qkv_sequence_length * past_sequence_length,
298- 1 );
299- std::string mask_shape_ones_shape_name = " webnn_GQA_left_constant_of_scatter_indices_" + std::to_string (batch_size) +
300- " _" + std::to_string (num_heads) + " _" + std::to_string (qkv_sequence_length) +
301- " _" + std::to_string (past_sequence_length);
302- emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant <int32_t >(
303- ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape,
304- std::vector<uint32_t >({batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
298+ emscripten::val value_int_one_constant =
299+ model_builder.CreateOrGetConstant <int >(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1 , {1 });
300+
301+ std::vector<uint32_t > mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length};
302+ common_options.set (" label" , node.Name () + " _/GQA/GQA_mask_shape_ones/expand" );
303+ emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder ().call <emscripten::val>(
304+ " expand" , value_int_one_constant, emscripten::val::array (mask_shape_ones_shape), common_options);
305305
306306 emscripten::val cumsum_options = emscripten::val::object ();
307307 cumsum_options.set (" label" , node.Name () + " _range_of_mask_shape" );
@@ -315,7 +315,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
315315 std::iota (pre_neq_right_data_range.begin (), pre_neq_right_data_range.end (), 1 );
316316
317317 std::string pre_neq_right_data_range_name =
318- " webnn_GQA_left_constant_of_scatter_indices_ " + std::to_string (qkv_sequence_length);
318+ " webnn_GQA_pre_neq_right_data_range_ " + std::to_string (qkv_sequence_length);
319319 emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant <int32_t >(
320320 ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range,
321321 std::vector<uint32_t >({qkv_sequence_length}));
@@ -333,10 +333,46 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
333333 emscripten::val neq_right =
334334 model_builder.GetBuilder ().call <emscripten::val>(" transpose" , expanded_neq_right, transpose_options);
335335
336- common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition " );
337- emscripten::val condition =
336+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_1 " );
337+ emscripten::val condition_1 =
338338 model_builder.GetBuilder ().call <emscripten::val>(" lesser" , neq_left, neq_right, common_options);
339339
340+ emscripten::val condition = condition_1;
341+ // For local window size not equal to -1, new attention mask pattern for applying sliding window
342+ /*
343+ condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
344+ | |
345+ | Lesser <--- local_window_size
346+ | |
347+ LogicalAnd <----------------- condition_2
348+ |
349+ new attn_mask
350+ */
351+ if (local_window_size != -1 ) {
352+ // Cast condition
353+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2/cast" );
354+ emscripten::val casted_condition_1 =
355+ model_builder.GetBuilder ().call <emscripten::val>(" cast" , condition_1, emscripten::val (" int32" ), common_options);
356+
357+ cumsum_options = emscripten::val::object ();
358+ cumsum_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2/cumsum" );
359+ cumsum_options.set (" exclusive" , true );
360+ cumsum_options.set (" reversed" , true );
361+ emscripten::val neq_left_2 = model_builder.GetBuilder ().call <emscripten::val>(
362+ " cumulativeSum" , casted_condition_1, gsl::narrow<uint32_t >(3 ), cumsum_options);
363+
364+ emscripten::val local_window_size_constant =
365+ model_builder.CreateOrGetConstant <int >(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1 });
366+
367+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2" );
368+ emscripten::val condition_2 =
369+ model_builder.GetBuilder ().call <emscripten::val>(" lesser" , neq_left_2, local_window_size_constant, common_options);
370+
371+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition/and" );
372+ condition = model_builder.GetBuilder ().call <emscripten::val>(
373+ " logicalAnd" , condition_1, condition_2, common_options);
374+ }
375+
340376 emscripten::val value_one_constant =
341377 model_builder.CreateOrGetConstant <float >(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1 , {1 });
342378
0 commit comments