@@ -295,14 +295,13 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
295295 |
296296 attention_bias
297297 */
298- const std::vector<int32_t > mask_shape_ones_shape (batch_size * num_heads * qkv_sequence_length * past_sequence_length,
299- 1 );
300- std::string mask_shape_ones_shape_name = " webnn_GQA_mask_shape_ones_" + std::to_string (batch_size) +
301- " _" + std::to_string (num_heads) + " _" + std::to_string (qkv_sequence_length) +
302- " _" + std::to_string (past_sequence_length);
303- emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant <int32_t >(
304- ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape,
305- std::vector<uint32_t >({batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
298+ emscripten::val value_int_one_constant =
299+ model_builder.CreateOrGetConstant <int >(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1 , {1 });
300+
301+ std::vector<uint32_t > mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length};
302+ common_options.set (" label" , node.Name () + " _/GQA/GQA_mask_shape_ones/expand" );
303+ emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder ().call <emscripten::val>(
304+ " expand" , value_int_one_constant, emscripten::val::array (mask_shape_ones_shape), common_options);
306305
307306 emscripten::val cumsum_options = emscripten::val::object ();
308307 cumsum_options.set (" label" , node.Name () + " _range_of_mask_shape" );
@@ -350,8 +349,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
350349 new attn_mask
351350 */
352351 if (local_window_size != -1 ) {
353- emscripten::val console = emscripten::val::global (" console" );
354- console.call <void >(" log" , emscripten::val (" local window size is not -1." ));
355352 // Cast condition
356353 common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition_2/cast" );
357354 emscripten::val casted_condition_1 =
@@ -374,7 +371,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
374371 common_options.set (" label" , node.Name () + " _/GQA/attn_mask/condition/and" );
375372 condition = model_builder.GetBuilder ().call <emscripten::val>(
376373 " logicalAnd" , condition_1, condition_2, common_options);
377- console.call <void >(" log" , condition_2);
378374 }
379375
380376 emscripten::val value_one_constant =
0 commit comments