Skip to content

Commit e9ef8ae

Browse files
committed
decompose large constant
1 parent 6af6346 commit e9ef8ae

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -295,14 +295,13 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
295295
|
296296
attention_bias
297297
*/
298-
const std::vector<int32_t> mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length,
299-
1);
300-
std::string mask_shape_ones_shape_name = "webnn_GQA_mask_shape_ones_" + std::to_string(batch_size) +
301-
"_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) +
302-
"_" + std::to_string(past_sequence_length);
303-
emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant<int32_t>(
304-
ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape,
305-
std::vector<uint32_t>({batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
298+
emscripten::val value_int_one_constant =
299+
model_builder.CreateOrGetConstant<int>(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1});
300+
301+
std::vector<uint32_t> mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length};
302+
common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand");
303+
emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call<emscripten::val>(
304+
"expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options);
306305

307306
emscripten::val cumsum_options = emscripten::val::object();
308307
cumsum_options.set("label", node.Name() + "_range_of_mask_shape");
@@ -350,8 +349,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
350349
new attn_mask
351350
*/
352351
if (local_window_size != -1) {
353-
emscripten::val console = emscripten::val::global("console");
354-
console.call<void>("log", emscripten::val("local window size is not -1."));
355352
// Cast condition
356353
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast");
357354
emscripten::val casted_condition_1 =
@@ -374,7 +371,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
374371
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and");
375372
condition = model_builder.GetBuilder().call<emscripten::val>(
376373
"logicalAnd", condition_1, condition_2, common_options);
377-
console.call<void>("log", condition_2);
378374
}
379375

380376
emscripten::val value_one_constant =

0 commit comments

Comments
 (0)