Skip to content

Commit ddaf674

Browse files
committed
[webgpu] Use multiplication instead of pow if exponent is 2
1 parent f02a640 commit ddaf674

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,15 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
322322
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
323323
round_str = "round";
324324
}
325-
std::string use_sqrt_for_pow;
325+
std::string use_sqrt_or_mul_for_pow;
326326
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
327327
// use sqrt instead of pow when base (a) is a non-negative float and exponent (b) is 0.5
328-
use_sqrt_for_pow =
328+
// use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0
329+
use_sqrt_or_mul_for_pow =
329330
" else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
330331
" return sqrt(a);\n"
332+
" } else if (b == 2.0) {\n"
333+
" return a * a;\n"
331334
" }\n";
332335
}
333336

@@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
337340
" } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n"
338341
" return input_a_element_t(pow(f32(a), b)); // NaN\n"
339342
" }\n"
340-
<< use_sqrt_for_pow
343+
<< use_sqrt_or_mul_for_pow
341344
<< " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n"
342345
<< "}\n"
343346
"fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n"

0 commit comments

Comments
 (0)