Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
19cd198
Initial plan
Copilot Mar 18, 2026
92aaedf
Fix example 10 hang: reset locks tensor at start of run_experiment
Copilot Mar 18, 2026
92b9696
Add benchmark sweep script and re-enable example 10 in CI
Copilot Mar 18, 2026
eeb72a1
Add benchmark-sweep CI workflow with PR comment reporting
Copilot Mar 18, 2026
e022124
Update benchmark-sweep workflow: add pull_request trigger to auto-run…
Copilot Mar 18, 2026
bd4c3a0
Revert benchmark-sweep workflow and report generator script
Copilot Mar 18, 2026
efe486e
Fix three bugs in example 10: preamble fn, acquire/release atomics, m…
Copilot Mar 18, 2026
eadb1e0
Fix Bug 3: use int32 locks + .wt data store + atomic_xchg/atomic_cas …
Copilot Mar 18, 2026
ba01295
Remove accidentally committed temp files
Copilot Mar 18, 2026
484b71c
Address code review: add preamble barrier comments and optimize spin-…
Copilot Mar 18, 2026
75d8461
Remove sweep_benchmark.py from example 10
Copilot Mar 19, 2026
e8969d5
Add bf16 roofline plot for example 10 benchmark sweep (8x MI325X, 8 r…
Copilot Mar 19, 2026
f7a1efa
Remove roofline PNG, update CI perf threshold to 1440 TFLOPS (~80% of…
Copilot Mar 19, 2026
b6e66d2
Add torchrun detection to benchmark.py to fix CI failure (same patter…
Copilot Mar 19, 2026
3dac33d
Merge branch 'main' into copilot/fix-example-10-hang
mawad-amd Mar 19, 2026
4b476d4
Fix validation failure: add dedicated validation run before validate_…
Copilot Mar 19, 2026
5549cc3
Merge branch 'main' into copilot/fix-example-10-hang
mawad-amd Mar 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions .github/workflows/iris-performance-regression-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ jobs:
matrix:
# Performance baselines measured on AMD Instinct MI325X (8 GPUs)
include:
# Disabled https://github.com/ROCm/iris/issues/238
#- example_name: "GEMM All-Scatter WG Specialization"
# example_path: "10_gemm_all_scatter_wg_specialization"
# tflops_threshold: 1600 # Actual: ~2182 TFLOPs
# benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256"
- example_name: "GEMM All-Scatter WG Specialization"
example_path: "10_gemm_all_scatter_wg_specialization"
tflops_threshold: 1440 # Actual: ~1802 TFLOPs (80% regression threshold)
benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256"

- example_name: "GEMM All-Scatter"
example_path: "07_gemm_all_scatter"
Expand Down
52 changes: 39 additions & 13 deletions examples/10_gemm_all_scatter_wg_specialization/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
Expand Down Expand Up @@ -132,7 +134,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N

locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

bias = None

Expand All @@ -153,13 +155,18 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Allocate Timestamps
timestamps = Timestamps(num_tiles=total_tiles)

def preamble():
# Zero the shared `locks` flags, with a barrier on each side of the reset.
# Called directly before the dedicated validation run and also handed to
# iris.do_bench as its third argument alongside run_experiment.
# Barrier 1: ensure all ranks finish previous iteration before clearing locks
shmem.barrier()
locks.zero_()
# Barrier 2: ensure all ranks see zeroed locks before any rank starts the kernel
shmem.barrier()

def run_experiment():
nonlocal local_C
nonlocal global_C
nonlocal kernel_timing

shmem.barrier()

if args["trace_tiles"]:
timestamps.reset()
shmem.barrier()
Expand Down Expand Up @@ -215,6 +222,16 @@ def run_experiment():
kernel_timing[k]["experiments"] = 0

if args["validate"]:
# Run a dedicated validation kernel to ensure all cross-GPU writes are fully
# propagated before checking results. The warmup above may leave some
# iris.put stores in-flight on the xGMI interconnect; the extra
# preamble + run + barrier cycle guarantees all ranks have flushed their
# GPU caches and that rank-0 sees every scattered tile before we call
# validate_gemm.
preamble()
run_experiment()
shmem.barrier()

shmem.info("Validating...")
matmul.set_debug(True)
# Validate global result
Expand All @@ -241,7 +258,7 @@ def run_experiment():
matmul.set_debug(False)
shmem.info("Benchmarking...")
perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
triton_ms = iris.do_bench(run_experiment, shmem.barrier)
triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
triton_tflops = perf(triton_ms)
algo_string = "all_scatter"
shmem.info(
Expand Down Expand Up @@ -275,15 +292,24 @@ def run_experiment():
def main():
args = parse_args()

num_ranks = args["num_ranks"]

init_url = "tcp://127.0.0.1:29500"
mp.spawn(
fn=_worker,
args=(num_ranks, init_url, args),
nprocs=num_ranks,
join=True,
)
# Check if running with torchrun (detected by environment variables)
if "RANK" in os.environ and "LOCAL_RANK" in os.environ:
# torchrun handles process spawning, so call _worker directly
print("Detected torchrun execution mode")
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500")
_worker(rank, world_size, f"tcp://{init_url}", args)
else:
# Use multiprocessing spawn for backward compatibility
num_ranks = args["num_ranks"]
init_url = "tcp://127.0.0.1:29500"
mp.spawn(
fn=_worker,
args=(num_ranks, init_url, args),
nprocs=num_ranks,
join=True,
)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def persistent_gemm_all_scatter_wg_specialization(
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt")
tl.debug_barrier()
tl.store(locks + tile_id, 1, cache_modifier=".wt")
tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu")

else: # pid >= GEMM_SMS
COMM_SMS = NUM_SMS - GEMM_SMS
Expand All @@ -163,8 +162,11 @@ def persistent_gemm_all_scatter_wg_specialization(
global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
# End: masks/offset calculations.

# Spin-wait: poll with a cheap volatile load until the lock reads 1, then issue
# an acquire CAS (compare with 1, swap in 1, so the stored value is unchanged)
# purely to ensure memory ordering once the lock is observed set.
while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
pass
tl.atomic_cas(locks + tile_id, 1, 1, sem="acquire", scope="gpu")

for remote_rank in range(world_size):
if remote_rank != cur_rank:
Expand Down
Loading