Skip to content

Commit 3ceb846

Browse files
committed
Optimize handling of slinky buffer objects
PiperOrigin-RevId: 836039592
1 parent 43f9c2f commit 3ceb846

File tree

1 file changed

+50
-25
lines changed

1 file changed

+50
-25
lines changed

ynnpack/subgraph/elementwise.cc

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,25 @@ auto make_unary_elementwise_impl(unary_kernel_fn kernel) {
137137
};
138138
}
139139

140+
inline const slinky::dim& dim_or_broadcast(const slinky::raw_buffer& buf,
141+
std::ptrdiff_t d) {
142+
return d < static_cast<std::ptrdiff_t>(buf.rank) ? buf.dim(d)
143+
: slinky::dim::broadcast();
144+
}
145+
146+
inline bool same_bounds(const slinky::dim& a, const slinky::dim& b) {
147+
// Return true if the dimensions have the same min and max or if they are
148+
// both broadcasts.
149+
return (a.min() == b.min() && a.max() == b.max()) ||
150+
(a.stride() == 0 && b.stride() == 0);
151+
}
152+
153+
template <typename... Dims>
154+
inline bool same_bounds(const slinky::dim& a, const slinky::dim& b,
155+
const Dims&... dims) {
156+
return same_bounds(a, b) && same_bounds(b, dims...);
157+
}
158+
140159
// Binary kernels only support a single global params object, i.e. it must be
141160
// globally broadcasted. Currently, the only operation that needs to support
142161
// non-scalar params is `convert` with non-scalar quantization data.
@@ -148,33 +167,39 @@ auto make_binary_elementwise_impl(binary_kernel_fn kernel,
148167
slinky::buffer<const void, YNN_MAX_TENSOR_RANK> a,
149168
slinky::buffer<const void, YNN_MAX_TENSOR_RANK> b,
150169
slinky::buffer<void, YNN_MAX_TENSOR_RANK> x) -> slinky::index_t {
151-
// Try to fuse dimensions where possible.
152-
slinky::optimize_dims(x, a, b);
170+
slinky::dim a_dims[2], b_dims[2], x_dims[2];
171+
for (int i = 0; i < 2; ++i) {
172+
a_dims[i] = dim_or_broadcast(a, 0);
173+
b_dims[i] = dim_or_broadcast(b, 0);
174+
x_dims[i] = dim_or_broadcast(x, 0);
175+
176+
// `x` is already a view to the correct tile in the larger output buffer.
177+
// Inputs `a` and `b` are not. We must explicitly set their offsets
178+
// according to `x` before slicing.
179+
if (a.rank > 0) a.slice(0, x_dims[i].min());
180+
if (b.rank > 0) b.slice(0, x_dims[i].min());
181+
if (x.rank > 0) x.slice(0);
153182

154-
// We're going to handle the two innermost dimensions with the kernel, or
155-
// treat them as broadcasts if there aren't two dimensions.
156-
const slinky::dim broadcast(0, 0, 0, 0);
183+
while (x.rank > 0 && same_bounds(x_dims[i], a_dims[i], b_dims[i]) &&
184+
slinky::can_fuse(x_dims[i], x.dim(0)) &&
185+
slinky::can_fuse(a_dims[i], dim_or_broadcast(a, 0)) &&
186+
slinky::can_fuse(b_dims[i], dim_or_broadcast(b, 0))) {
187+
a_dims[i] = slinky::fuse(a_dims[i], dim_or_broadcast(a, 0));
188+
b_dims[i] = slinky::fuse(b_dims[i], dim_or_broadcast(b, 0));
189+
x_dims[i] = slinky::fuse(x_dims[i], x.dim(0));
190+
191+
if (a.rank > 0) a.slice(0);
192+
if (b.rank > 0) b.slice(0);
193+
x.slice(0);
194+
}
195+
}
157196

158-
const slinky::dim& a_n = a.rank > 0 ? a.dim(0) : broadcast;
159-
const slinky::dim& b_n = b.rank > 0 ? b.dim(0) : broadcast;
160-
const slinky::dim& x_n = x.rank > 0 ? x.dim(0) : broadcast;
161-
const slinky::dim& a_m = a.rank > 1 ? a.dim(1) : broadcast;
162-
const slinky::dim& b_m = b.rank > 1 ? b.dim(1) : broadcast;
163-
const slinky::dim& x_m = x.rank > 1 ? x.dim(1) : broadcast;
164-
165-
assert(!a_n.is_folded(x_n));
166-
assert(!b_n.is_folded(x_n));
167-
assert(!x_n.is_folded());
168-
assert(!a_m.is_folded(x_m));
169-
assert(!b_m.is_folded(x_m));
170-
assert(!x_m.is_folded());
171-
172-
if (a.rank > 0) a.slice(0, x.dim(0).min());
173-
if (b.rank > 0) b.slice(0, x.dim(0).min());
174-
if (x.rank > 0) x.slice(0);
175-
if (a.rank > 0) a.slice(0, x.dim(0).min());
176-
if (b.rank > 0) b.slice(0, x.dim(0).min());
177-
if (x.rank > 0) x.slice(0);
197+
const slinky::dim& a_n = a_dims[0];
198+
const slinky::dim& b_n = b_dims[0];
199+
const slinky::dim& x_n = x_dims[0];
200+
const slinky::dim& a_m = a_dims[1];
201+
const slinky::dim& b_m = b_dims[1];
202+
const slinky::dim& x_m = x_dims[1];
178203

179204
slinky::for_each_element(
180205
[&](void* x, const void* a, const void* b) {

0 commit comments

Comments
 (0)