diff --git a/src/AssociativeOpsTable.cpp b/src/AssociativeOpsTable.cpp index badc4120372a..a4cfeffda198 100644 --- a/src/AssociativeOpsTable.cpp +++ b/src/AssociativeOpsTable.cpp @@ -203,6 +203,16 @@ void populate_ops_table_double_general_sub(const vector &types, vector &types, vector &table) { declare_vars_double(types); + // Argmax with index as first tuple element + table.push_back({{select(x1 > y1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true}); + table.push_back({{select(x1 >= y1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true}); + table.push_back({{select(y1 < x1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true}); + table.push_back({{select(y1 <= x1, x0, y0), max(x1, y1)}, {zero_0, tmin_1}, true}); + // Argmin with index as first tuple element + table.push_back({{select(x1 < y1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true}); + table.push_back({{select(x1 <= y1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true}); + table.push_back({{select(y1 > x1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true}); + table.push_back({{select(y1 >= x1, x0, y0), min(x1, y1)}, {zero_0, tmax_1}, true}); } void populate_ops_table_single_uint1_and(const vector &types, vector &table) { @@ -326,19 +336,6 @@ const vector &get_ops_table_helper(const vector &types return table_it->second; } -std::string print_types(const vector &types) { - std::ostringstream stream; - stream << "{"; - for (size_t i = 0; i < types.size(); ++i) { - if (i > 0) { - stream << ", "; - } - stream << types[i]; - } - stream << "}"; - return stream.str(); -} - } // anonymous namespace const vector &get_ops_table(const vector &exprs) { diff --git a/src/Associativity.cpp b/src/Associativity.cpp index 6a8d8948be85..bd67f0245af6 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -18,7 +18,6 @@ namespace Halide { namespace Internal { using std::map; -using std::pair; using std::set; using std::string; using std::vector; @@ -103,8 +102,14 @@ bool associative_op_pattern_match(const Expr &e, << "Expr has type " 
<< e.type() << ", while pattern has type " << op.type() << "\n"; map result; if (expr_match(op, e, result)) { - debug(5) << "Found associative ops for " << e << " -> " << op - << ", y_part: " << result["y0"] << "\n"; + debug(5) << "Found associative ops for " << e << " -> " << op << ":\n" + << [&] { + std::stringstream ss; + for (const auto &[var, val] : result) { + ss << " " << var << " -> " << val << "\n"; + } + return ss.str(); + }(); for (size_t i = 0; i < x_names.size(); ++i) { const auto &iter = result.find("x" + std::to_string(i)); @@ -187,7 +192,6 @@ bool find_match(const vector &table, const vector &o continue; } - vector> replacement; // find -> replacement for (size_t index = 0; index < op_y_names.size(); ++index) { const auto &y_iter = pattern_match.find("y" + std::to_string(index)); if (y_iter == pattern_match.end()) { @@ -202,20 +206,25 @@ bool find_match(const vector &table, const vector &o assoc_op.xs[index] = {op_x_names[index], x_parts[index]}; assoc_op.ys[index] = {op_y_names[index], y_part}; - replacement.emplace_back(y_part, Variable::make(y_part.type(), op_y_names[index])); } if (!matched) { continue; } - for (size_t index = 0; index < exprs.size(); ++index) { - Expr e = exprs[index]; - // Order of substitution matters, e.g. in the argmin case, _y_0 -> g(rx)[0] - // and _y_1 -> rx. If we substitute the 2nd element rx first, substitution - // of g(rx)[0] will fail. - for (const auto &iter : replacement) { - e = substitute(iter.first, iter.second, e); + // Build the concrete ops by renaming the pattern's abstract + // wildcard variables (x0, y0, k0, ...) to the actual variable + // names used in the expressions. 
+ map replacement; + for (size_t index = 0; index < op_x_names.size(); ++index) { + replacement["x" + std::to_string(index)] = Variable::make(exprs[index].type(), op_x_names[index]); + replacement["y" + std::to_string(index)] = Variable::make(exprs[index].type(), op_y_names[index]); + } + for (const auto &[wildcard, identity] : pattern_match) { + if (wildcard[0] == 'k') { + replacement[wildcard] = identity; } - assoc_op.pattern.ops[index] = e; + } + for (size_t index = 0; index < pattern.ops.size(); ++index) { + assoc_op.pattern.ops[index] = substitute(replacement, pattern.ops[index]); assoc_op.pattern.identities[index] = pattern.identities[index]; } assoc_op.pattern.is_commutative = pattern.is_commutative; @@ -225,7 +234,7 @@ bool find_match(const vector &table, const vector &o } // Return a pair of booleans indicating if an operator is associative. -// 'assoc_op' contains the the equivalent associative binary/unary operator +// 'assoc_op' contains the equivalent associative binary/unary operator // for that operator. If the operator is non-associative, 'assoc_op' is not valid. bool extract_associative_op(const vector &exprs, const vector &op_x_names, const vector &op_y_names, const vector &x_parts, @@ -236,7 +245,7 @@ bool extract_associative_op(const vector &exprs, const vector &op_ // An update that just assigns some value is not associative, // because there's no good identity. An identity is necessary // because things like rfactor will combine the identity with - // partially-computed values and expect it to do nothing. For an + // partially computed values and expect it to do nothing. 
For an // example, see https://github.com/halide/Halide/issues/7893 return false; } else if (equal(exprs[0], Variable::make(t, op_x_names[0]))) { @@ -256,58 +265,44 @@ bool extract_associative_op(const vector &exprs, const vector &op_ x_parts, exprs, assoc_op); } -void add_transitive_dependencies(vector> &dependencies) { - // TODO(psuriana): there might be a better way to find all the transitive dependencies - bool change = true; - while (change) { - change = false; +bool is_subset_of(const std::set &a, const std::set &b) { + return std::includes(b.begin(), b.end(), a.begin(), a.end()); +} + +// Compute the dependency subgraphs for a tuple reduction. First closes the +// dependency relation transitively, then returns only the earliest (by index) +// maximal dependency sets, clearing any set contained in a dominating one. +vector> compute_subgraphs(vector> dependencies) { + // Compute the transitive closure using Warshall's algorithm. + for (size_t k = 0; k < dependencies.size(); ++k) { for (size_t i = 0; i < dependencies.size(); ++i) { - for (size_t j = 0; j < dependencies.size(); ++j) { - if (i == j) { - continue; - } - if (dependencies[i].count(j)) { - for (const auto &idx : dependencies[j]) { - if (dependencies[i].count(idx) == 0) { - dependencies[i].insert(idx); - change = true; - } - } + if (dependencies[i].count(k)) { + for (int j : dependencies[k]) { + dependencies[i].insert(j); } } } } -} -// Given dependencies of each tuple element, compute the set of subgraphs: -// all vertices that are reachable from a given vertex. If a subgraph is fully -// contained in another subgraph, remove it from the final output. -vector> compute_subgraphs(vector> dependencies) { + // Keep only maximal dependency sets. A set is removed if another + // set strictly contains it or is identical but has a lower index. 
vector> subgraphs(dependencies.size()); for (size_t i = 0; i < dependencies.size(); ++i) { - // Check if the current subgraph is a subset of another - const auto &current = dependencies[i]; - if (current.empty()) { + if (dependencies[i].empty()) { continue; } - bool should_remove = false; + bool is_maximal = true; for (size_t j = 0; j < dependencies.size(); ++j) { - const auto &other = dependencies[j]; - if ((i == j) || (current.size() > other.size()) || (j < i && subgraphs[i].empty())) { - continue; - } - vector diff; - // Compute the vertices in the current set that are not contained in the other - std::set_difference(current.begin(), current.end(), other.begin(), other.end(), - std::inserter(diff, diff.begin())); - if (diff.empty()) { - // 'current' is fully contained in 'other' - should_remove = true; + const bool can_dominate = + (dependencies[j].size() > dependencies[i].size()) || + (dependencies[j].size() == dependencies[i].size() && j < i); + if (can_dominate && is_subset_of(dependencies[i], dependencies[j])) { + is_maximal = false; break; } } - if (!should_remove) { - subgraphs[i] = current; + if (is_maximal) { + subgraphs[i] = dependencies[i]; } } return subgraphs; @@ -353,8 +348,8 @@ AssociativeOp prove_associativity(const string &f, vector args, vector args, vector> subgraphs; if (!all_independent) { debug(5) << "There are cross-dependencies. Need to prove associativity in bulk.\n"; - // Find all transitive dependencies and add them to the graph - add_transitive_dependencies(dependencies); // Decompose the tuple into subgraphs and solve for each separately subgraphs = compute_subgraphs(dependencies); } else { diff --git a/src/Associativity.h b/src/Associativity.h index e4b7ffd155b5..a28735fca4b5 100644 --- a/src/Associativity.h +++ b/src/Associativity.h @@ -106,7 +106,7 @@ struct AssociativeOp { /** * Given an update definition of a Func 'f', determine its equivalent * associative binary/unary operator if there is any. 
'is_associative' - * indicates if the operation was successfuly proven as associative. + * indicates if the operation was successfully proven as associative. */ AssociativeOp prove_associativity( const std::string &f, std::vector args, std::vector exprs); diff --git a/test/correctness/rfactor.cpp b/test/correctness/rfactor.cpp index 30f2b9787e95..e26751f2a0dc 100644 --- a/test/correctness/rfactor.cpp +++ b/test/correctness/rfactor.cpp @@ -831,12 +831,70 @@ int argmin_rfactor_test() { return 0; } +enum class InlineReductionVariant { + ArgMin, + ArgMax, +}; + +template +int inline_reductions_test() { + using namespace ConciseCasts; + constexpr float pi = M_PI; + + Func f{"f"}; + Var x("x"); + f(x) = sin(f32(x) / 8 * pi); // argmax should be f(4) = 1.0, argmin should be f(12) = -1.0 + f.compute_root(); + + RDom r(0, 32); + + Func g{"reduction"}; + Func output{"g"}; + + if constexpr (variant == InlineReductionVariant::ArgMin) { + output() = argmin(f(r), g); + } else { + output() = argmax(f(r), g); + } + + RVar ro("rxo"), ri("rxi"); + g.update(0).split(r, ro, ri, 2); + + Var u("u"); + Func intm = g.update(0).rfactor(ro, u); + intm.compute_root(); + intm.update(0).vectorize(u, 2); + + Realization rn = output.realize(); + Buffer sch_idx(rn[0]); + Buffer sch_val(rn[1]); + + if constexpr (variant == InlineReductionVariant::ArgMin) { + if (sch_val() != -1.0f || sch_idx() != 12) { + fprintf(stderr, "Expected argmin to be f(12) = -1.0, got f(%d) = %f\n", sch_idx(), sch_val()); + return 1; + } + } else { + if (sch_val() != 1.0f || sch_idx() != 4) { + fprintf(stderr, "Expected argmax to be f(4) = 1.0, got f(%d) = %f\n", sch_idx(), sch_val()); + return 1; + } + } + + return 0; +} + enum class ArgMaxVariant { Explicit, TupleSelect }; -template +enum class ArgMaxTupleOrder { + IndexFirst, + ValueFirst, +}; + +template int argmax_rfactor_test() { using namespace ConciseCasts; constexpr float pi = M_PI; @@ -849,12 +907,29 @@ int argmax_rfactor_test() { RDom r(0, 32); Func g{"g"}; - 
g() = Tuple(f.type().min(), r.x.min()); + + int value_tup = order == ArgMaxTupleOrder::ValueFirst ? 0 : 1; + int index_tup = order == ArgMaxTupleOrder::ValueFirst ? 1 : 0; + + if constexpr (order == ArgMaxTupleOrder::ValueFirst) { + g() = Tuple(f.type().min(), r.x.min()); + } else { + g() = Tuple(r.x.min(), f.type().min()); + } + if constexpr (variant == ArgMaxVariant::Explicit) { - g() = Tuple(max(f(r), g()[0]), select(g()[0] < f(r), r, g()[1])); + if constexpr (order == ArgMaxTupleOrder::ValueFirst) { + g() = Tuple(max(f(r), g()[value_tup]), select(g()[value_tup] < f(r), r, g()[index_tup])); + } else { + g() = Tuple(select(g()[value_tup] < f(r), r, g()[index_tup]), max(f(r), g()[value_tup])); + } } else { static_assert(variant == ArgMaxVariant::TupleSelect); - g() = select(g()[0] < f(r), Tuple(f(r), r), g()); + if constexpr (order == ArgMaxTupleOrder::ValueFirst) { + g() = select(g()[value_tup] < f(r), Tuple(f(r), r), g()); + } else { + g() = select(g()[value_tup] < f(r), Tuple(r, f(r)), g()); + } } RVar ro("rxo"), ri("rxi"); @@ -866,8 +941,8 @@ int argmax_rfactor_test() { intm.update(0).vectorize(u, 2); Realization rn = g.realize(); - Buffer sch_val(rn[0]); - Buffer sch_idx(rn[1]); + Buffer sch_val(rn[value_tup]); + Buffer sch_idx(rn[index_tup]); if (sch_val() != 1.0f || sch_idx() != 4) { fprintf(stderr, "Expected argmax to be f(4) = 1.0, got f(%d) = %f\n", sch_idx(), sch_val()); @@ -1208,8 +1283,12 @@ int main(int argc, char **argv) { {"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test}, {"complex multiply rfactor test", complex_multiply_rfactor_test}, {"argmin rfactor test", argmin_rfactor_test}, - {"argmax rfactor test (explicit)", argmax_rfactor_test}, - {"argmax rfactor test (tuple)", argmax_rfactor_test}, + {"inline reductions test (argmin)", inline_reductions_test}, + {"inline reductions test (argmax)", inline_reductions_test}, + {"argmax rfactor test (explicit, index first)", argmax_rfactor_test}, + {"argmax 
rfactor test (tuple, index first)", argmax_rfactor_test}, + {"argmax rfactor test (explicit, value first)", argmax_rfactor_test}, + {"argmax rfactor test (tuple, value first)", argmax_rfactor_test}, {"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test}, {"rfactor bounds tests", rfactor_precise_bounds_test}, {"isnan max rfactor test (bitwise or)", isnan_max_rfactor_test},