Skip to content

Commit 4c43c66

Browse files
authored
[webgpu] Use multiplication instead of pow if exponent is 2 (#26667)
### Description More accurately compute Pow(2.0) on WebGPU EP. Reproduction script: ```py from onnx import helper, TensorProto import onnxruntime as ort import numpy as np # 1. Create the ONNX model # Define input and output input_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1]) output_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1]) # Create a constant tensor for the exponent (2.0) exponent_tensor = helper.make_tensor('exponent', TensorProto.FLOAT, [], [2.0]) exponent_node = helper.make_node('Constant', [], ['exponent_out'], value=exponent_tensor) # Create the Pow node # Pow takes two inputs: Base (X) and Power (exponent_out) pow_node = helper.make_node( 'Pow', inputs=['X', 'exponent_out'], outputs=['Y'], name='PowNode' ) # Create the graph graph_def = helper.make_graph( [exponent_node, pow_node], 'test-model', [input_info], [output_info] ) # Create the model model_def = helper.make_model(graph_def, producer_name='onnx-example') opset = model_def.opset_import[0] opset.version = 13 # Ensure opset version supports the operations # 2. Convert model to string (bytes) model_str = model_def.SerializeToString() # 3. Prepare input data np.random.seed(0) input_data = np.array([[-2e3]], dtype=np.float32) # 4. Run on CPUExecutionProvider sess_cpu = ort.InferenceSession(model_str, providers=['CPUExecutionProvider']) res_cpu = sess_cpu.run(['Y'], {'X': input_data})[0] print("CPU Result:", res_cpu) # 5. Run on WebGpuExecutionProvider sess_webgpu = ort.InferenceSession(model_str, providers=['WebGpuExecutionProvider']) res_webgpu = sess_webgpu.run(['Y'], {'X': input_data})[0] print("WebGPU Result:", res_webgpu) # Compare results diff = np.abs(res_cpu - res_webgpu) max_diff = diff.max().item() assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" print("Results match!") ``` currently produces ``` CPU Result: [[4.e+06]] WebGPU Result: [[3.999999e+06]] --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[1], [line 56](vscode-notebook-cell:?execution_count=1&line=56) 54 diff = np.abs(res_cpu - res_webgpu) 55 max_diff = diff.max().item() ---> [56](vscode-notebook-cell:?execution_count=1&line=56) assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" 57 print("Results match!") AssertionError: Results do not match within tolerance! Max diff: 1.0 ``` but with this PR: ``` CPU Result: [[4.e+06]] WebGPU Result: [[4.e+06]] Results match! ``` ### Motivation and Context Leads to downstream issues/inaccuracies for certain models, especially those which have larger values to compute pow(x,2) for. cc @guschmue
1 parent f02a640 commit 4c43c66

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
322322
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
323323
round_str = "round";
324324
}
325-
std::string use_sqrt_for_pow;
325+
std::string use_pow_shortcut;
326326
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
327+
// use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0
327328
// use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5
328-
use_sqrt_for_pow =
329-
" else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
329+
use_pow_shortcut =
330+
" else if (b == 2.0) {\n"
331+
" return a * a;\n"
332+
" } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n"
330333
" return sqrt(a);\n"
331334
" }\n";
332335
}
@@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
337340
" } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n"
338341
" return input_a_element_t(pow(f32(a), b)); // NaN\n"
339342
" }\n"
340-
<< use_sqrt_for_pow
343+
<< use_pow_shortcut
341344
<< " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n"
342345
<< "}\n"
343346
"fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n"

0 commit comments

Comments
 (0)