@@ -137,6 +137,25 @@ auto make_unary_elementwise_impl(unary_kernel_fn kernel) {
137137 };
138138}
139139
140+ inline const slinky::dim& dim_or_broadcast (const slinky::raw_buffer& buf,
141+ std::ptrdiff_t d) {
142+ return d < static_cast <std::ptrdiff_t >(buf.rank ) ? buf.dim (d)
143+ : slinky::dim::broadcast ();
144+ }
145+
146+ inline bool same_bounds (const slinky::dim& a, const slinky::dim& b) {
147+ // Return true if the dimensions have the same min and max or if they are
148+ // both broadcasts.
149+ return (a.min () == b.min () && a.max () == b.max ()) ||
150+ (a.stride () == 0 && b.stride () == 0 );
151+ }
152+
153+ template <typename ... Dims>
154+ inline bool same_bounds (const slinky::dim& a, const slinky::dim& b,
155+ const Dims&... dims) {
156+ return same_bounds (a, b) && same_bounds (b, dims...);
157+ }
158+
140159// Binary kernels only support a single global params object, i.e. it must be
141160// globally broadcasted. Currently, the only operation that needs to support
142161// non-scalar params is `convert` with non-scalar quantization data.
@@ -148,33 +167,39 @@ auto make_binary_elementwise_impl(binary_kernel_fn kernel,
148167 slinky::buffer<const void , YNN_MAX_TENSOR_RANK> a,
149168 slinky::buffer<const void , YNN_MAX_TENSOR_RANK> b,
150169 slinky::buffer<void , YNN_MAX_TENSOR_RANK> x) -> slinky::index_t {
151- // Try to fuse dimensions where possible.
152- slinky::optimize_dims (x, a, b);
170+ slinky::dim a_dims[2 ], b_dims[2 ], x_dims[2 ];
171+ for (int i = 0 ; i < 2 ; ++i) {
172+ a_dims[i] = dim_or_broadcast (a, 0 );
173+ b_dims[i] = dim_or_broadcast (b, 0 );
174+ x_dims[i] = dim_or_broadcast (x, 0 );
175+
176+ // `x` is already a view to the correct tile in the larger output buffer.
177+ // Inputs `a` and `b` are not. We must explicitly set their offsets
178+ // according to `x` before slicing.
179+ if (a.rank > 0 ) a.slice (0 , x_dims[i].min ());
180+ if (b.rank > 0 ) b.slice (0 , x_dims[i].min ());
181+ if (x.rank > 0 ) x.slice (0 );
153182
154- // We're going to handle the two innermost dimensions with the kernel, or
155- // treat them as broadcasts if there aren't two dimensions.
156- const slinky::dim broadcast (0 , 0 , 0 , 0 );
183+ while (x.rank > 0 && same_bounds (x_dims[i], a_dims[i], b_dims[i]) &&
184+ slinky::can_fuse (x_dims[i], x.dim (0 )) &&
185+ slinky::can_fuse (a_dims[i], dim_or_broadcast (a, 0 )) &&
186+ slinky::can_fuse (b_dims[i], dim_or_broadcast (b, 0 ))) {
187+ a_dims[i] = slinky::fuse (a_dims[i], dim_or_broadcast (a, 0 ));
188+ b_dims[i] = slinky::fuse (b_dims[i], dim_or_broadcast (b, 0 ));
189+ x_dims[i] = slinky::fuse (x_dims[i], x.dim (0 ));
190+
191+ if (a.rank > 0 ) a.slice (0 );
192+ if (b.rank > 0 ) b.slice (0 );
193+ x.slice (0 );
194+ }
195+ }
157196
158- const slinky::dim& a_n = a.rank > 0 ? a.dim (0 ) : broadcast;
159- const slinky::dim& b_n = b.rank > 0 ? b.dim (0 ) : broadcast;
160- const slinky::dim& x_n = x.rank > 0 ? x.dim (0 ) : broadcast;
161- const slinky::dim& a_m = a.rank > 1 ? a.dim (1 ) : broadcast;
162- const slinky::dim& b_m = b.rank > 1 ? b.dim (1 ) : broadcast;
163- const slinky::dim& x_m = x.rank > 1 ? x.dim (1 ) : broadcast;
164-
165- assert (!a_n.is_folded (x_n));
166- assert (!b_n.is_folded (x_n));
167- assert (!x_n.is_folded ());
168- assert (!a_m.is_folded (x_m));
169- assert (!b_m.is_folded (x_m));
170- assert (!x_m.is_folded ());
171-
172- if (a.rank > 0 ) a.slice (0 , x.dim (0 ).min ());
173- if (b.rank > 0 ) b.slice (0 , x.dim (0 ).min ());
174- if (x.rank > 0 ) x.slice (0 );
175- if (a.rank > 0 ) a.slice (0 , x.dim (0 ).min ());
176- if (b.rank > 0 ) b.slice (0 , x.dim (0 ).min ());
177- if (x.rank > 0 ) x.slice (0 );
197+ const slinky::dim& a_n = a_dims[0 ];
198+ const slinky::dim& b_n = b_dims[0 ];
199+ const slinky::dim& x_n = x_dims[0 ];
200+ const slinky::dim& a_m = a_dims[1 ];
201+ const slinky::dim& b_m = b_dims[1 ];
202+ const slinky::dim& x_m = x_dims[1 ];
178203
179204 slinky::for_each_element (
180205 [&](void * x, const void * a, const void * b) {
0 commit comments