@@ -1368,7 +1368,7 @@ class GPUTilingDedup {
13681368 }
13691369
13701370 /** Generate Halide GPU schedules. */
1371- void apply (AutoSchedule &sched) {
1371+ void apply (AutoSchedule &sched, const Expr& parallelism ) {
13721372 if (!ordering.empty () && !is_initial_order) {
13731373 std::set<std::string> var_list;
13741374 for (const auto &v : ordering) {
@@ -1396,7 +1396,7 @@ class GPUTilingDedup {
13961396 }
13971397
13981398 GPUTileHelper helper{f, stage_num};
1399- Expr threads_budget = max_n_threads;
1399+ Expr threads_budget = min (parallelism, max_n_threads) ;
14001400
14011401 // Maximize GPU thread occupancy with the grid-stride loop.
14021402 //
@@ -1423,22 +1423,22 @@ class GPUTilingDedup {
14231423
14241424 const auto &[var, entry] = *iter;
14251425
1426- const bool should_unroll = can_prove (entry.factor <= 1 );
1427- if (should_unroll) {
1428- // Skip thread size of 1.
1429- continue ;
1430- }
1426+ // const bool should_unroll = can_prove(entry.factor <= 1);
1427+ // if (should_unroll) {
1428+ // // Skip thread size of 1.
1429+ // continue;
1430+ // }
14311431
14321432 split_info new_entry{entry};
1433- new_entry.factor = 1 ;
1433+ new_entry.factor = simplify ( min (threads_budget, entry. factor )) ;
14341434
14351435 const bool can_split = helper.try_split (new_entry);
14361436 if (!can_split) {
14371437 // If more than 3 gpu_blocks are defined, mark the current loop as the for-loop.
14381438 parallelize.erase (iter);
14391439 continue ;
14401440 }
1441- threads_budget = simplify (max (threads_budget / entry .factor , 1 ));
1441+ threads_budget = simplify (max (threads_budget / new_entry .factor , 1 ));
14421442 }
14431443
14441444 helper.commit (sched, is_compute_at);
@@ -2210,7 +2210,7 @@ Partitioner::find_best_tile_config(const Group &g) {
22102210 Group no_tile = g;
22112211 no_tile.tile_sizes = no_tile_config;
22122212
2213- bool show_analysis = false ;
2213+ constexpr bool show_analysis = false ;
22142214 GroupAnalysis no_tile_analysis = analyze_group (no_tile, show_analysis);
22152215
22162216 GroupAnalysis best_analysis = no_tile_analysis;
@@ -2233,7 +2233,7 @@ Partitioner::find_best_tile_config(const Group &g) {
22332233 Expr benefit = estimate_benefit (best_analysis, new_analysis,
22342234 no_redundant_work, true );
22352235
2236- if (show_analysis) {
2236+ if constexpr (show_analysis) {
22372237 debug (0 ) << " Benefit relative to not tiling:" << benefit << " \n " ;
22382238 debug (0 ) << " Best analysis:" << new_analysis;
22392239 debug (0 ) << " No tile analysis:" << no_tile_analysis;
@@ -3439,7 +3439,8 @@ void Partitioner::generate_group_cpu_schedule(
34393439 }
34403440 }
34413441 if (arch_params.is_gpu_schedule ) {
3442- auto parallelized_split = gpu_tiling.can_parallelize (v, iter->second );
3442+ const Expr gpu_threads = simplify (min (iter->second , arch_params.parallelism / def_par));
3443+ auto parallelized_split = gpu_tiling.can_parallelize (v, gpu_threads);
34433444 if (parallelized_split) {
34443445 auto split_vars = *parallelized_split;
34453446 inner_dims.emplace_back (split_vars.inner );
@@ -3463,7 +3464,7 @@ void Partitioner::generate_group_cpu_schedule(
34633464 }
34643465
34653466 if (arch_params.is_gpu_schedule ) {
3466- gpu_tiling.apply (sched);
3467+ gpu_tiling.apply (sched, arch_params. parallelism );
34673468 }
34683469
34693470 // Find the level at which group members will be computed.
@@ -3552,7 +3553,7 @@ void Partitioner::generate_group_cpu_schedule(
35523553 mem_rvars, mem_estimates, sched, gpu_tiling2);
35533554
35543555 if (arch_params.is_gpu_schedule ) {
3555- gpu_tiling2.apply (sched);
3556+ gpu_tiling2.apply (sched, arch_params. parallelism );
35563557 }
35573558 }
35583559}
0 commit comments