Skip to content

Commit d7ca50b

Browse files
committed
Switch comparison order
1 parent ddaf674 commit d7ca50b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -322,15 +322,15 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
322322
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
323323
round_str = "round";
324324
}
325-
std::string use_sqrt_or_mul_for_pow;
325+
std::string use_pow_shortcut;
326326
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
327-
// use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5
328327
// use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0
329-
use_sqrt_or_mul_for_pow =
330-
" else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
331-
" return sqrt(a);\n"
332-
" } else if (b == 2.0) {\n"
328+
// use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5
329+
use_pow_shortcut =
330+
" else if (b == 2.0) {\n"
333331
" return a * a;\n"
332+
" } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
333+
" return sqrt(a);\n"
334334
" }\n";
335335
}
336336

@@ -340,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
340340
" } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n"
341341
" return input_a_element_t(pow(f32(a), b)); // NaN\n"
342342
" }\n"
343-
<< use_sqrt_or_mul_for_pow
343+
<< use_pow_shortcut
344344
<< " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n"
345345
<< "}\n"
346346
"fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n"

0 commit comments

Comments
 (0)