diff --git a/.github/workflows/iris-performance-regression-test.yml b/.github/workflows/iris-performance-regression-test.yml index ebde87df..fc7081b8 100644 --- a/.github/workflows/iris-performance-regression-test.yml +++ b/.github/workflows/iris-performance-regression-test.yml @@ -24,11 +24,10 @@ jobs: matrix: # Performance baselines measured on AMD Instinct MI325X (8 GPUs) include: - # Disabled https://github.com/ROCm/iris/issues/238 - #- example_name: "GEMM All-Scatter WG Specialization" - # example_path: "10_gemm_all_scatter_wg_specialization" - # tflops_threshold: 1600 # Actual: ~2182 TFLOPs - # benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256" + - example_name: "GEMM All-Scatter WG Specialization" + example_path: "10_gemm_all_scatter_wg_specialization" + tflops_threshold: 1440 # Actual: ~1802 TFLOPs (80% regression threshold) + benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256" - example_name: "GEMM All-Scatter" example_path: "07_gemm_all_scatter" diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 655c892f..910ebdd6 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +import os + import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -132,7 +134,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) bias = None @@ -153,13 +155,18 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + def preamble(): + # Barrier 1: ensure all ranks finish previous iteration before clearing locks + shmem.barrier() + locks.zero_() + # Barrier 2: ensure all ranks see zeroed locks before any rank starts the kernel + shmem.barrier() + def run_experiment(): nonlocal local_C nonlocal global_C nonlocal kernel_timing - shmem.barrier() - if args["trace_tiles"]: timestamps.reset() shmem.barrier() @@ -215,6 +222,16 @@ def run_experiment(): kernel_timing[k]["experiments"] = 0 if args["validate"]: + # Run a dedicated validation kernel to ensure all cross-GPU writes are fully + # propagated before checking results. The warmup above may leave some + # iris.put stores in-flight on the xGMI interconnect; the extra + # preamble + run + barrier cycle guarantees all ranks have flushed their + # GPU caches and that rank-0 sees every scattered tile before we call + # validate_gemm. + preamble() + run_experiment() + shmem.barrier() + shmem.info("Validating...") matmul.set_debug(True) # Validate global result @@ -241,7 +258,7 @@ def run_experiment(): matmul.set_debug(False) shmem.info("Benchmarking...") perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) triton_tflops = perf(triton_ms) algo_string = "all_scatter" shmem.info( @@ -275,15 +292,24 @@ def run_experiment(): def main(): args = parse_args() - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ and "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500") + _worker(rank, world_size, f"tcp://{init_url}", args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 4d9c2825..643e84f9 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -140,8 +140,7 @@ def persistent_gemm_all_scatter_wg_specialization( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") - tl.debug_barrier() - tl.store(locks + tile_id, 1, cache_modifier=".wt") + tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu") else: # pid >= GEMM_SMS COMM_SMS = NUM_SMS - GEMM_SMS @@ -163,8 +162,11 @@ def persistent_gemm_all_scatter_wg_specialization( global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global # End: masks/offset calculations. + # Spin-wait: first check with a cheap volatile load, then acquire-CAS to + # ensure memory ordering once the lock is observed set. while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: pass + tl.atomic_cas(locks + tile_id, 1, 1, sem="acquire", scope="gpu") for remote_rank in range(world_size): if remote_rank != cur_rank: