diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 82645e30082e6..3c974ef5133c0 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,11 +322,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_sqrt_for_pow; + std::string use_pow_shortcut; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_sqrt_for_pow = - " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_pow_shortcut = + " else if (b == 2.0) {\n" + " return a * a;\n" + " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_sqrt_for_pow + << use_pow_shortcut << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n"