@@ -2812,33 +2812,16 @@ static enum xnn_status optimize_common_subgraphs_static_reshapes(
28122812 return xnn_status_success ;
28132813 }
28142814
2815- // Set the shape of the static-shaped value.
2816- struct xnn_shape new_shape ;
2817- if (node -> type == xnn_node_type_static_reshape ) {
2818- // Replace the old shape with the new shape, filling any gaps from the input
2819- // shape.
2820- new_shape = node -> params .static_reshape .new_shape ;
2821- XNN_RETURN_IF_ERROR (xnn_shape_fill_gaps (& input_value -> shape , & new_shape ));
2822- } else if (node -> type == xnn_node_type_static_expand_dims ) {
2823- const struct xnn_shape * new_dims = & node -> params .static_reshape .new_shape ;
2824- new_shape .num_dims = input_value -> shape .num_dims + new_dims -> num_dims ;
2825- for (uint32_t idx_new = 0 , idx_old = 0 , k = 0 ; k < new_shape .num_dims ;
2826- k ++ ) {
2827- if (idx_new < new_dims -> num_dims && new_dims -> dim [idx_new ] == k ) {
2828- new_shape .dim [k ] = 1 ;
2829- idx_new ++ ;
2830- } else {
2831- new_shape .dim [k ] = input_value -> shape .dim [idx_old ++ ];
2832- }
2833- }
2834- }
2835-
28362815 // If the input is a static value, apply the new shape to it directly.
28372816 bool elide = true;
28382817 if (xnn_value_is_static (input_value -> allocation_type )) {
2839- input_value -> shape = new_shape ;
2840- } else {
2841- elide = xnn_shape_match (& new_shape , & input_value -> shape );
2818+ input_value -> shape = output_value -> shape ;
2819+ }
2820+
2821+ // Otherwise, if the new shape is the old shape, do away with the reshape
2822+ // entirely.
2823+ else {
2824+ elide = xnn_shape_match (& input_value -> shape , & output_value -> shape );
28422825 }
28432826
28442827 if (elide ) {
@@ -3427,6 +3410,12 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34273410 }
34283411 }
34293412
3413+ // If we don't know the shape of the constant, then we can't really guarantee
3414+ // anything.
3415+ if ((const_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC ) == 0 ) {
3416+ return xnn_status_success ;
3417+ }
3418+
34303419 const enum xnn_binary_operator binary_operator = node -> binary_operator ;
34313420 const bool const_is_zero = (const_value -> flags & XNN_VALUE_FLAG_IS_ZERO ) != 0 ;
34323421 const bool const_is_one = (const_value -> flags & XNN_VALUE_FLAG_IS_ONE ) != 0 ;
@@ -3453,32 +3442,29 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34533442 (const_is_zero &&
34543443 (binary_operator == xnn_binary_add ||
34553444 (binary_operator == xnn_binary_subtract && const_is_rhs )))) {
3456- if (short_circuit (subgraph , input_value -> id , node -> outputs [0 ])) {
3457- xnn_log_info ("Elided spurious %s[#%u](v%03u, %s)." ,
3458- binary_operator == xnn_binary_multiply ? "mul"
3459- : binary_operator == xnn_binary_divide ? "div"
3460- : binary_operator == xnn_binary_add ? "add"
3461- : "sub" ,
3462- node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" );
3463- xnn_node_clear (node );
3464- (* changes )++ ;
3465- } else if (input_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC &&
3466- const_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC ) {
3467- // If this node cannot be elided, and both input shapes are static, then
3468- // try to replace it with a `copy` or `broadcast` of the input value.
3469- struct xnn_shape * output_shape = & subgraph -> values [node -> outputs [0 ]].shape ;
3470- XNN_RETURN_IF_ERROR (
3471- xnn_shape_binary_broadcast (& input_value -> shape , & const_value -> shape ,
3472- output_shape ),
3473- "Incompatible input shapes for %s[#%u](v%03u, %s)." ,
3474- binary_operator == xnn_binary_multiply ? "mul"
3475- : binary_operator == xnn_binary_divide ? "div"
3476- : binary_operator == xnn_binary_add ? "add"
3477- : "sub" ,
3478- node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" );
3479-
3480- // If the output shape matches the input shape, just copy the input.
3481- if (xnn_shape_match (& input_value -> shape , output_shape )) {
3445+ // We can safely elide this operation if we know it will not change the
3446+ // shape of the output, e.g. if the constant is a scalar or the shapes are
3447+ // static and equal.
3448+ if ((xnn_shape_multiply_all_dims (& const_value -> shape ) == 1 &&
3449+ const_value -> shape .num_dims <= input_value -> shape .num_dims ) ||
3450+ ((input_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC ) &&
3451+ xnn_shape_match (& input_value -> shape , & output_value -> shape ))) {
3452+ // If the node can be elided (not load-bearing), then just remove it.
3453+ if (short_circuit (subgraph , input_value -> id , node -> outputs [0 ])) {
3454+ xnn_log_info ("Elided spurious %s[#%u](v%03u, %s)." ,
3455+ binary_operator == xnn_binary_multiply ? "mul"
3456+ : binary_operator == xnn_binary_divide ? "div"
3457+ : binary_operator == xnn_binary_add ? "add"
3458+ : "sub" ,
3459+ node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" );
3460+ xnn_node_clear (node );
3461+ (* changes )++ ;
3462+ }
3463+
3464+ // Otherwise, replace it with a copy.
3465+ else {
3466+ // If the constant is a scalar, then it won't affect the shape of the
3467+ // output.
34823468 XNN_RETURN_IF_ERROR (xnn_define_copy (subgraph , input_value -> id ,
34833469 node -> outputs [0 ], node -> flags ),
34843470 "Failed to create new `Copy` node." );
@@ -3492,26 +3478,6 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34923478 node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" , node_id ,
34933479 input_value -> id );
34943480 }
3495-
3496- // Otherwise, we need to broadcast the input to the output shape.
3497- else {
3498- XNN_RETURN_IF_ERROR (
3499- xnn_define_static_broadcast (subgraph , output_shape -> num_dims ,
3500- output_shape -> dim , input_value -> id ,
3501- node -> outputs [0 ],
3502- node -> flags ),
3503- "Failed to create new `Broadcast` node." );
3504- node = move_last_node_to (subgraph , node_id );
3505- xnn_log_info (
3506- "Replaced spurious %s[#%u](v%03u, %s) with "
3507- "static_broadcast[#%u](v%03u)." ,
3508- binary_operator == xnn_binary_multiply ? "mul"
3509- : binary_operator == xnn_binary_divide ? "div"
3510- : binary_operator == xnn_binary_add ? "add"
3511- : "sub" ,
3512- node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" , node_id ,
3513- input_value -> id );
3514- }
35153481 (* changes )++ ;
35163482 }
35173483 }
@@ -3637,9 +3603,55 @@ static enum xnn_status optimize_common_subgraphs_iter(
36373603 XNN_VALUE_FLAG_SHAPE_IS_STATIC ) != 0 ;
36383604 }
36393605 if (all_input_shapes_are_static ) {
3640- for (int k = 0 ; k < node -> num_outputs ; k ++ ) {
3641- subgraph -> values [node -> outputs [k ]].flags |=
3642- XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3606+ switch (node -> type ) {
3607+ case xnn_node_type_unary_elementwise :
3608+ subgraph -> values [node -> outputs [0 ]].shape =
3609+ subgraph -> values [node -> inputs [0 ]].shape ;
3610+ subgraph -> values [node -> outputs [0 ]].flags |=
3611+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3612+ break ;
3613+
3614+ case xnn_node_type_binary_elementwise :
3615+ xnn_shape_binary_broadcast (& subgraph -> values [node -> inputs [0 ]].shape ,
3616+ & subgraph -> values [node -> inputs [1 ]].shape ,
3617+ & subgraph -> values [node -> outputs [0 ]].shape );
3618+ subgraph -> values [node -> outputs [0 ]].flags |=
3619+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3620+ break ;
3621+
3622+ case xnn_node_type_static_transpose :
3623+ // Apply the transpose to the output shape.
3624+ for (int k = 0 ; k < node -> params .transpose .num_dims ; k ++ ) {
3625+ subgraph -> values [node -> outputs [0 ]].shape .dim [k ] =
3626+ subgraph -> values [node -> inputs [0 ]]
3627+ .shape .dim [node -> params .transpose .perm [k ]];
3628+ }
3629+ subgraph -> values [node -> outputs [0 ]].flags |=
3630+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3631+ break ;
3632+
3633+ case xnn_node_type_static_expand_dims : {
3634+ const struct xnn_shape * new_dims =
3635+ & node -> params .static_reshape .new_shape ;
3636+ subgraph -> values [node -> outputs [0 ]].shape .num_dims =
3637+ subgraph -> values [node -> inputs [0 ]].shape .num_dims +
3638+ new_dims -> num_dims ;
3639+ for (uint32_t idx_new = 0 , idx_in = 0 , k = 0 ;
3640+ k < subgraph -> values [node -> outputs [0 ]].shape .num_dims ; k ++ ) {
3641+ if (idx_new < new_dims -> num_dims && new_dims -> dim [idx_new ] == k ) {
3642+ subgraph -> values [node -> outputs [0 ]].shape .dim [k ] = 1 ;
3643+ idx_new ++ ;
3644+ } else {
3645+ subgraph -> values [node -> outputs [0 ]].shape .dim [k ] =
3646+ subgraph -> values [node -> inputs [0 ]].shape .dim [idx_in ++ ];
3647+ }
3648+ }
3649+ subgraph -> values [node -> outputs [0 ]].flags |=
3650+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3651+ } break ;
3652+
3653+ default :
3654+ break ;
36433655 }
36443656 }
36453657
0 commit comments