Skip to content

Commit 658b15f

Browse files
gonnet authored and xnnpack-bot committed
Be more strict with elided arithmetic op shapes.
PiperOrigin-RevId: 831383616
1 parent 4065278 commit 658b15f

File tree

2 files changed

+91
-80
lines changed

2 files changed

+91
-80
lines changed

src/subgraph.c

Lines changed: 85 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,33 +2812,16 @@ static enum xnn_status optimize_common_subgraphs_static_reshapes(
28122812
return xnn_status_success;
28132813
}
28142814

2815-
// Set the shape of the static-shaped value.
2816-
struct xnn_shape new_shape;
2817-
if (node->type == xnn_node_type_static_reshape) {
2818-
// Replace the old shape with the new shape, filling any gaps from the input
2819-
// shape.
2820-
new_shape = node->params.static_reshape.new_shape;
2821-
XNN_RETURN_IF_ERROR(xnn_shape_fill_gaps(&input_value->shape, &new_shape));
2822-
} else if (node->type == xnn_node_type_static_expand_dims) {
2823-
const struct xnn_shape* new_dims = &node->params.static_reshape.new_shape;
2824-
new_shape.num_dims = input_value->shape.num_dims + new_dims->num_dims;
2825-
for (uint32_t idx_new = 0, idx_old = 0, k = 0; k < new_shape.num_dims;
2826-
k++) {
2827-
if (idx_new < new_dims->num_dims && new_dims->dim[idx_new] == k) {
2828-
new_shape.dim[k] = 1;
2829-
idx_new++;
2830-
} else {
2831-
new_shape.dim[k] = input_value->shape.dim[idx_old++];
2832-
}
2833-
}
2834-
}
2835-
28362815
// If the input is a static value, apply the new shape to it directly.
28372816
bool elide = true;
28382817
if (xnn_value_is_static(input_value->allocation_type)) {
2839-
input_value->shape = new_shape;
2840-
} else {
2841-
elide = xnn_shape_match(&new_shape, &input_value->shape);
2818+
input_value->shape = output_value->shape;
2819+
}
2820+
2821+
// Otherwise, if the new shape is the old shape, do away with the reshape
2822+
// entirely.
2823+
else {
2824+
elide = xnn_shape_match(&input_value->shape, &output_value->shape);
28422825
}
28432826

28442827
if (elide) {
@@ -3427,6 +3410,12 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34273410
}
34283411
}
34293412

3413+
// If we don't know the shape of the constant, then we can't really guarantee
3414+
// anything.
3415+
if ((const_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC) == 0) {
3416+
return xnn_status_success;
3417+
}
3418+
34303419
const enum xnn_binary_operator binary_operator = node->binary_operator;
34313420
const bool const_is_zero = (const_value->flags & XNN_VALUE_FLAG_IS_ZERO) != 0;
34323421
const bool const_is_one = (const_value->flags & XNN_VALUE_FLAG_IS_ONE) != 0;
@@ -3453,32 +3442,29 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34533442
(const_is_zero &&
34543443
(binary_operator == xnn_binary_add ||
34553444
(binary_operator == xnn_binary_subtract && const_is_rhs)))) {
3456-
if (short_circuit(subgraph, input_value->id, node->outputs[0])) {
3457-
xnn_log_info("Elided spurious %s[#%u](v%03u, %s).",
3458-
binary_operator == xnn_binary_multiply ? "mul"
3459-
: binary_operator == xnn_binary_divide ? "div"
3460-
: binary_operator == xnn_binary_add ? "add"
3461-
: "sub",
3462-
node->id, input_value->id, const_is_zero ? "0.0" : "1.0");
3463-
xnn_node_clear(node);
3464-
(*changes)++;
3465-
} else if (input_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC &&
3466-
const_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC) {
3467-
// If this node cannot be elided, and both input shapes are static, then
3468-
// try to replace it with a `copy` or `broadcast` of the input value.
3469-
struct xnn_shape* output_shape = &subgraph->values[node->outputs[0]].shape;
3470-
XNN_RETURN_IF_ERROR(
3471-
xnn_shape_binary_broadcast(&input_value->shape, &const_value->shape,
3472-
output_shape),
3473-
"Incompatible input shapes for %s[#%u](v%03u, %s).",
3474-
binary_operator == xnn_binary_multiply ? "mul"
3475-
: binary_operator == xnn_binary_divide ? "div"
3476-
: binary_operator == xnn_binary_add ? "add"
3477-
: "sub",
3478-
node->id, input_value->id, const_is_zero ? "0.0" : "1.0");
3479-
3480-
// If the output shape matches the input shape, just copy the input.
3481-
if (xnn_shape_match(&input_value->shape, output_shape)) {
3445+
// We can safely elide this operation if we know it will not change the
3446+
// shape of the output, e.g. if the constant is a scalar or the shapes are
3447+
// static and equal.
3448+
if ((xnn_shape_multiply_all_dims(&const_value->shape) == 1 &&
3449+
const_value->shape.num_dims <= input_value->shape.num_dims) ||
3450+
((input_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC) &&
3451+
xnn_shape_match(&input_value->shape, &output_value->shape))) {
3452+
// If the node be elided (not load-bearing), then just remove it.
3453+
if (short_circuit(subgraph, input_value->id, node->outputs[0])) {
3454+
xnn_log_info("Elided spurious %s[#%u](v%03u, %s).",
3455+
binary_operator == xnn_binary_multiply ? "mul"
3456+
: binary_operator == xnn_binary_divide ? "div"
3457+
: binary_operator == xnn_binary_add ? "add"
3458+
: "sub",
3459+
node->id, input_value->id, const_is_zero ? "0.0" : "1.0");
3460+
xnn_node_clear(node);
3461+
(*changes)++;
3462+
}
3463+
3464+
// Otherwise, replace it with a copy.
3465+
else {
3466+
// If the constant is a scalar, then it won't affect the shape of the
3467+
// output.
34823468
XNN_RETURN_IF_ERROR(xnn_define_copy(subgraph, input_value->id,
34833469
node->outputs[0], node->flags),
34843470
"Failed to create new `Copy` node.");
@@ -3492,26 +3478,6 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34923478
node->id, input_value->id, const_is_zero ? "0.0" : "1.0", node_id,
34933479
input_value->id);
34943480
}
3495-
3496-
// Otherwise, we need to broadcast the input to the output shape.
3497-
else {
3498-
XNN_RETURN_IF_ERROR(
3499-
xnn_define_static_broadcast(subgraph, output_shape->num_dims,
3500-
output_shape->dim, input_value->id,
3501-
node->outputs[0],
3502-
node->flags),
3503-
"Failed to create new `Broadcast` node.");
3504-
node = move_last_node_to(subgraph, node_id);
3505-
xnn_log_info(
3506-
"Replaced spurious %s[#%u](v%03u, %s) with "
3507-
"static_broadcast[#%u](v%03u).",
3508-
binary_operator == xnn_binary_multiply ? "mul"
3509-
: binary_operator == xnn_binary_divide ? "div"
3510-
: binary_operator == xnn_binary_add ? "add"
3511-
: "sub",
3512-
node->id, input_value->id, const_is_zero ? "0.0" : "1.0", node_id,
3513-
input_value->id);
3514-
}
35153481
(*changes)++;
35163482
}
35173483
}
@@ -3637,9 +3603,55 @@ static enum xnn_status optimize_common_subgraphs_iter(
36373603
XNN_VALUE_FLAG_SHAPE_IS_STATIC) != 0;
36383604
}
36393605
if (all_input_shapes_are_static) {
3640-
for (int k = 0; k < node->num_outputs; k++) {
3641-
subgraph->values[node->outputs[k]].flags |=
3642-
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3606+
switch (node->type) {
3607+
case xnn_node_type_unary_elementwise:
3608+
subgraph->values[node->outputs[0]].shape =
3609+
subgraph->values[node->inputs[0]].shape;
3610+
subgraph->values[node->outputs[0]].flags |=
3611+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3612+
break;
3613+
3614+
case xnn_node_type_binary_elementwise:
3615+
xnn_shape_binary_broadcast(&subgraph->values[node->inputs[0]].shape,
3616+
&subgraph->values[node->inputs[1]].shape,
3617+
&subgraph->values[node->outputs[0]].shape);
3618+
subgraph->values[node->outputs[0]].flags |=
3619+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3620+
break;
3621+
3622+
case xnn_node_type_static_transpose:
3623+
// Apply the transpose to the output shape.
3624+
for (int k = 0; k < node->params.transpose.num_dims; k++) {
3625+
subgraph->values[node->outputs[0]].shape.dim[k] =
3626+
subgraph->values[node->inputs[0]]
3627+
.shape.dim[node->params.transpose.perm[k]];
3628+
}
3629+
subgraph->values[node->outputs[0]].flags |=
3630+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3631+
break;
3632+
3633+
case xnn_node_type_static_expand_dims: {
3634+
const struct xnn_shape* new_dims =
3635+
&node->params.static_reshape.new_shape;
3636+
subgraph->values[node->outputs[0]].shape.num_dims =
3637+
subgraph->values[node->inputs[0]].shape.num_dims +
3638+
new_dims->num_dims;
3639+
for (uint32_t idx_new = 0, idx_in = 0, k = 0;
3640+
k < subgraph->values[node->outputs[0]].shape.num_dims; k++) {
3641+
if (idx_new < new_dims->num_dims && new_dims->dim[idx_new] == k) {
3642+
subgraph->values[node->outputs[0]].shape.dim[k] = 1;
3643+
idx_new++;
3644+
} else {
3645+
subgraph->values[node->outputs[0]].shape.dim[k] =
3646+
subgraph->values[node->inputs[0]].shape.dim[idx_in++];
3647+
}
3648+
}
3649+
subgraph->values[node->outputs[0]].flags |=
3650+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3651+
} break;
3652+
3653+
default:
3654+
break;
36433655
}
36443656
}
36453657

test/subgraph/rewrites.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
#include <gmock/gmock.h>
2121
#include <gtest/gtest.h>
22-
#include "include/experimental.h"
2322
#include "include/xnnpack.h"
2423
#include "src/subgraph/subgraph-utils.h"
2524
#include "src/xnnpack/buffer.h"
@@ -844,7 +843,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeMul) {
844843
// Add a scalar static tensor with the value `1.0`.
845844
uint32_t static_one_value_id;
846845
std::tie(static_one_tensor, static_one_value_id) =
847-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 1.0, 1.0);
846+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 1.0, 1.0);
848847

849848
// Add the binary `multiply` op with the constant 1.0.
850849
auto inputs =
@@ -934,7 +933,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeDiv) {
934933
// Add a scalar static tensor with the value `1.0`.
935934
uint32_t static_one_value_id;
936935
std::tie(static_one_tensor, static_one_value_id) =
937-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 1.0, 1.0);
936+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 1.0, 1.0);
938937

939938
// Add the binary `divide` op by the constant 1.0.
940939
subgraph.AddBinary(xnn_binary_divide, /*params=*/nullptr,
@@ -1022,7 +1021,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeAdd) {
10221021
// Add a scalar static tensor with the value `0.0`.
10231022
uint32_t static_zero_value_id;
10241023
std::tie(static_zero_tensor, static_zero_value_id) =
1025-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 0.0, 0.0);
1024+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 0.0, 0.0);
10261025

10271026
// Add the binary `add` op with the constant 0.0.
10281027
auto inputs =
@@ -1112,7 +1111,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeSub) {
11121111
// Add a scalar static tensor with the value `0.0`.
11131112
uint32_t static_zero_value_id;
11141113
std::tie(static_zero_tensor, static_zero_value_id) =
1115-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 0.0, 0.0);
1114+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 0.0, 0.0);
11161115

11171116
// Add the binary `subtract` op with the constant 0.0.
11181117
subgraph.AddBinary(xnn_binary_subtract, /*params=*/nullptr,
@@ -1247,7 +1246,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpChainOfStaticShapeMulZeroAdd) {
12471246
// Add a scalar static tensor with the value `0.0`.
12481247
uint32_t static_zero_value_id;
12491248
std::tie(static_zero_tensor, static_zero_value_id) =
1250-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 0.0, 0.0);
1249+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 0.0, 0.0);
12511250

12521251
// Add the binary `multiply` op with the constant 0.0.
12531252
uint32_t dynamic_zero_value_id =
@@ -1361,7 +1360,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpChainOfStaticShapeDivOneMul) {
13611360
// Add a scalar static tensor with the value `1.0`.
13621361
uint32_t static_one_value_id;
13631362
std::tie(static_one_tensor, static_one_value_id) =
1364-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 1.0, 1.0);
1363+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 1.0, 1.0);
13651364

13661365
// Add the static `1.0` to the absolute value of the inputs to make sure
13671366
// they are non-negative

0 commit comments

Comments
 (0)