Skip to content

Commit 4065278

Browse files
vksnk and xnnpack-bot
authored and committed
Allow specifying a dimension order in ynn_runtime::make_schedule.
This enables reordering loops before scheduling, so that the loop order can be taken into account when the splits and the number of workers are computed. PiperOrigin-RevId: 831653398
1 parent decc685 commit 4065278

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

ynnpack/subgraph/dot.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <cstdint>
1414
#include <cstring>
1515
#include <memory>
16+
#include <numeric>
1617
#include <optional>
1718
#include <tuple>
1819
#include <utility>
@@ -1036,23 +1037,23 @@ ynn_status ynn_define_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
10361037
split_n = {};
10371038
}
10381039

1040+
std::vector<slinky::index_t> loop_order = {0, 1};
1041+
if (pack_b && !packed_b.is_static()) {
1042+
// Loop over n first so we don't redundantly compute the packing for each
1043+
// split of m.
1044+
std::swap(loop_order[0], loop_order[1]);
1045+
}
1046+
10391047
slinky::expr splits[] = {split_n, split_m};
1040-
auto sched =
1041-
runtime.make_schedule(dims, output.buffer, node.outputs[0], splits);
1048+
auto sched = runtime.make_schedule(dims, output.buffer, node.outputs[0],
1049+
splits, 1, loop_order);
10421050

10431051
// We want to use exactly these loop splits for two innermost dot loops.
10441052
for (size_t i = 0; i < std::min<std::size_t>(2, sched->loop_splits.size());
10451053
i++) {
10461054
sched->loop_splits[i].step_is_required = true;
10471055
}
10481056

1049-
if (pack_b && !packed_b.is_static() && sched->loop_splits.size() >= 2 &&
1050-
sched->loop_splits[1].axis == 1) {
1051-
// Loop over n first so we don't redundantly compute the packing for each
1052-
// split of m.
1053-
std::swap(sched->loop_splits[0], sched->loop_splits[1]);
1054-
}
1055-
10561057
// Schedule the output buffer to be stored at the same level as it's
10571058
// computed at.
10581059
ynn::scheduled_buffer sched_output_buffer = {output.buffer, 0};

ynnpack/subgraph/runtime.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ slinky::var ynn_runtime::make_global_variable(slinky::expr value,
7979
std::unique_ptr<ynn::scheduling_info> ynn_runtime::make_schedule(
8080
const std::vector<slinky::var>& dims, const slinky::buffer_expr_ptr output,
8181
uint32_t output_value, slinky::span<const slinky::expr> given_splits,
82-
const slinky::expr& element_cost) {
82+
const slinky::expr& element_cost,
83+
const std::vector<slinky::index_t>& loop_order) {
8384
auto sched = std::make_unique<ynn::scheduling_info>();
8485

8586
int max_threads = threadpool() ? threadpool()->thread_count() : 1;
@@ -104,7 +105,13 @@ std::unique_ptr<ynn::scheduling_info> ynn_runtime::make_schedule(
104105
output->elem_size() * element_cost);
105106
std::vector<slinky::expr> splits(rank);
106107
slinky::expr tile_area_so_far = 1;
107-
for (int d = 0; d < rank; ++d) {
108+
109+
auto get_loop_dim = [&](int index_d) {
110+
return index_d < loop_order.size() ? loop_order[index_d] : index_d;
111+
};
112+
113+
for (int index_d = 0; index_d < rank; ++index_d) {
114+
int d = get_loop_dim(index_d);
108115
if (!output_extents[d].defined()) continue;
109116
if (d < given_splits.size()) {
110117
splits[d] = given_splits[d];
@@ -124,7 +131,8 @@ std::unique_ptr<ynn::scheduling_info> ynn_runtime::make_schedule(
124131
std::vector<slinky::expr> workers(rank);
125132
slinky::expr threads_so_far = 1;
126133

127-
for (int d = rank - 1; d >= 0; --d) {
134+
for (int index_d = rank - 1; index_d >= 0; --index_d) {
135+
int d = get_loop_dim(index_d);
128136
if (max_threads == 1) {
129137
workers[d] = slinky::loop::serial;
130138
} else if (output_extents[d].defined() && splits[d].defined()) {
@@ -140,7 +148,8 @@ std::unique_ptr<ynn::scheduling_info> ynn_runtime::make_schedule(
140148
}
141149
}
142150

143-
for (int d = 0; d < rank; ++d) {
151+
for (int index_d = 0; index_d < rank; ++index_d) {
152+
int d = get_loop_dim(index_d);
144153
if (output_extents[d].defined() && splits[d].defined()) {
145154
sched->loop_splits.push_back({dims[d], splits[d], workers[d], d});
146155
}

ynnpack/subgraph/runtime.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ struct ynn_runtime {
7777
std::unique_ptr<ynn::scheduling_info> make_schedule(
7878
const std::vector<slinky::var>& dims, slinky::buffer_expr_ptr output,
7979
uint32_t output_value, slinky::span<const slinky::expr> given_splits = {},
80-
const slinky::expr& element_cost = 1);
80+
const slinky::expr& element_cost = 1,
81+
const std::vector<slinky::index_t>& loop_order = {});
8182

8283
slinky::buffer_expr_ptr null_buffer();
8384

0 commit comments

Comments
 (0)