Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,14 @@
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
round_str = "round";
}
std::string use_sqrt_for_pow;
std::string use_pow_shortcut;

Check warning on line 325 in onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc:325: Add #include <string> for string [build/include_what_you_use] [4]
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
// use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0
// use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5
use_sqrt_for_pow =
" else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
use_pow_shortcut =
" else if (b == 2.0) {\n"
" return a * a;\n"
" } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
" return sqrt(a);\n"
" }\n";
}
Expand All @@ -337,7 +340,7 @@
" } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n"
" return input_a_element_t(pow(f32(a), b)); // NaN\n"
" }\n"
<< use_sqrt_for_pow
<< use_pow_shortcut
<< " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n"
<< "}\n"
"fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n"
Expand Down
Loading