@@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
107107 ORT_RETURN_IF_NOT (GetShape (*input_defs[3 ], input_past_k_shape, logger), " Cannot get past_key shape" );
108108
109109 NodeAttrHelper helper (node);
110+ const int32_t local_window_size = helper.Get (" local_window_size" , -1 );
110111 const uint32_t kv_num_heads = helper.Get (" kv_num_heads" , 0 );
111112 const uint32_t num_heads = helper.Get (" num_heads" , 0 );
112113
@@ -290,13 +291,13 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
290291 | |
291292 +-------------------------------> Lesser <---------------------Transpose (1,0)
292293 |
293- 1 ---> Where <--- finfo_min (minimum value of FP32)
294+ 1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32)
294295 |
295296 attention_bias
296297 */
297298 const std::vector<int32_t > mask_shape_ones_shape (batch_size * num_heads * qkv_sequence_length * past_sequence_length,
298299 1 );
299- std::string mask_shape_ones_shape_name = " webnn_GQA_left_constant_of_scatter_indices_ " + std::to_string (batch_size) +
300+ std::string mask_shape_ones_shape_name = " webnn_GQA_mask_shape_ones_ " + std::to_string (batch_size) +
300301 " _" + std::to_string (num_heads) + " _" + std::to_string (qkv_sequence_length) +
301302 " _" + std::to_string (past_sequence_length);
302303 emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant <int32_t >(
@@ -315,7 +316,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
315316 std::iota (pre_neq_right_data_range.begin (), pre_neq_right_data_range.end (), 1 );
316317
317318 std::string pre_neq_right_data_range_name =
318- " webnn_GQA_left_constant_of_scatter_indices_ " + std::to_string (qkv_sequence_length);
319+ " webnn_GQA_pre_neq_right_data_range_ " + std::to_string (qkv_sequence_length);
319320 emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant <int32_t >(
320321 ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range,
321322 std::vector<uint32_t >({qkv_sequence_length}));
@@ -333,10 +334,49 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
333334 emscripten::val neq_right =
334335 model_builder.GetBuilder ().call <emscripten::val>(" transpose" , expanded_neq_right, transpose_options);
335336
336- common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition " );
337- emscripten::val condition =
337+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_1 " );
338+ emscripten::val condition_1 =
338339 model_builder.GetBuilder ().call <emscripten::val>(" lesser" , neq_left, neq_right, common_options);
339340
341+ emscripten::val condition = condition_1;
342+ // When local_window_size is not -1, build a new attention mask pattern that applies the sliding window:
343+ /*
344+ condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
345+ | |
346+ | Lesser <--- local_window_size
347+ | |
348+ LogicalAnd <----------------- condition_2
349+ |
350+ new attn_mask
351+ */
352+ if (local_window_size != -1 ) {
353+ emscripten::val console = emscripten::val::global (" console" );
354+ console.call <void >(" log" , emscripten::val (" local window size is not -1." ));
355+ // Cast condition
356+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2/cast" );
357+ emscripten::val casted_condition_1 =
358+ model_builder.GetBuilder ().call <emscripten::val>(" cast" , condition_1, emscripten::val (" int32" ), common_options);
359+
360+ cumsum_options = emscripten::val::object ();
361+ cumsum_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2/cumsum" );
362+ cumsum_options.set (" exclusive" , true );
363+ cumsum_options.set (" reversed" , true );
364+ emscripten::val neq_left_2 = model_builder.GetBuilder ().call <emscripten::val>(
365+ " cumulativeSum" , casted_condition_1, gsl::narrow<uint32_t >(3 ), cumsum_options);
366+
367+ emscripten::val local_window_size_constant =
368+ model_builder.CreateOrGetConstant <int >(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1 });
369+
370+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2" );
371+ emscripten::val condition_2 =
372+ model_builder.GetBuilder ().call <emscripten::val>(" lesser" , neq_left_2, local_window_size_constant, common_options);
373+
374+ common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition/and" );
375+ condition = model_builder.GetBuilder ().call <emscripten::val>(
376+ " logicalAnd" , condition_1, condition_2, common_options);
377+ console.call <void >(" log" , condition_2);
378+ }
379+
340380 emscripten::val value_one_constant =
341381 model_builder.CreateOrGetConstant <float >(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1 , {1 });
342382
0 commit comments