@@ -265,9 +265,13 @@ bool extract_associative_op(const vector<Expr> &exprs, const vector<string> &op_
265265 x_parts, exprs, assoc_op);
266266}
267267
268- // Given dependencies of each tuple element, compute the set of subgraphs:
269- // all vertices that are reachable from a given vertex. If a subgraph is fully
270- // contained in another subgraph, remove it from the final output.
268+ bool is_subset_of (const std::set<int > &a, const std::set<int > &b) {
269+ return std::includes (b.begin (), b.end (), a.begin (), a.end ());
270+ }
271+
272+ // Compute the dependency subgraphs for a tuple reduction. First closes the
273+ // dependency relation transitively, then returns only the earliest (by index)
274+ // maximal dependency sets, clearing any set contained in a dominating one.
271275vector<set<int >> compute_subgraphs (vector<set<int >> dependencies) {
272276 // Compute the transitive closure using Warshall's algorithm.
273277 for (size_t k = 0 ; k < dependencies.size (); ++k) {
@@ -280,31 +284,25 @@ vector<set<int>> compute_subgraphs(vector<set<int>> dependencies) {
280284 }
281285 }
282286
287+ // Keep only maximal dependency sets. A set is removed if another
288+ // set strictly contains it or is identical but has a lower index.
283289 vector<set<int >> subgraphs (dependencies.size ());
284290 for (size_t i = 0 ; i < dependencies.size (); ++i) {
285- // Check if the current subgraph is a subset of another
286- const auto ¤t = dependencies[i];
287- if (current.empty ()) {
291+ if (dependencies[i].empty ()) {
288292 continue ;
289293 }
290- bool should_remove = false ;
294+ bool is_maximal = true ;
291295 for (size_t j = 0 ; j < dependencies.size (); ++j) {
292- const auto &other = dependencies[j];
293- if ((i == j) || (current.size () > other.size ()) || (j < i && subgraphs[i].empty ())) {
294- continue ;
295- }
296- vector<int > diff;
297- // Compute the vertices in the current set that are not contained in the other
298- std::set_difference (current.begin (), current.end (), other.begin (), other.end (),
299- std::inserter (diff, diff.begin ()));
300- if (diff.empty ()) {
301- // 'current' is fully contained in 'other'
302- should_remove = true ;
296+ const bool can_dominate =
297+ (dependencies[j].size () > dependencies[i].size ()) ||
298+ (dependencies[j].size () == dependencies[i].size () && j < i);
299+ if (can_dominate && is_subset_of (dependencies[i], dependencies[j])) {
300+ is_maximal = false ;
303301 break ;
304302 }
305303 }
306- if (!should_remove ) {
307- subgraphs[i] = current ;
304+ if (is_maximal ) {
305+ subgraphs[i] = dependencies[i] ;
308306 }
309307 }
310308 return subgraphs;
0 commit comments