Skip to content

[Feature] Block-scaled GEMM support for MXFP8 on Blackwell#1945

Open
Rachmanino wants to merge 7 commits intotile-ai:mainfrom
Rachmanino:wt/mxfp8
Open

[Feature] Block-scaled GEMM support for MXFP8 on Blackwell#1945
Rachmanino wants to merge 7 commits intotile-ai:mainfrom
Rachmanino:wt/mxfp8

Conversation

@Rachmanino
Copy link
Copy Markdown
Collaborator

@Rachmanino Rachmanino commented Mar 18, 2026

  • Introduced a new Python file for MXFP8 block-scaled GEMM implementation.
  • Added intrinsic definitions for block-scaled MMA, UTCCP, and scale factor operations in the C++ backend.
  • Implemented corresponding TileLang functions for block-scaled GEMM and scale factor handling.
  • Enhanced CUDA code generation to support new block-scaled instructions.
  • Updated documentation to reflect new functionalities and usage examples.

This commit enhances the performance of GEMM operations on the SM100 architecture by leveraging block scaling and optimized memory access patterns.

Summary by CodeRabbit

  • New Features

    • Block‑scaled MXFP8 GEMM for SM100 with example scripts, verification, and benchmarking.
  • API / Intrinsics

    • Public interfaces to emit/run block‑scaled GEMM and produce TMEM layouts.
    • New device intrinsics and frontend wrappers: block‑scaled MMA, shared→tmem copy (UTCCP), scale‑factor warp‑transpose, SMEM descriptor builder, and before/after thread‑sync primitives.
  • Language / Exports

    • TileLang exports expanded to expose block‑scaled GEMM paths and related helpers.

- Introduced a new Python file for MXFP8 block-scaled GEMM implementation.
- Added intrinsic definitions for block-scaled MMA, UTCCP, and scale factor operations in the C++ backend.
- Implemented corresponding TileLang functions for block-scaled GEMM and scale factor handling.
- Enhanced CUDA code generation to support new block-scaled instructions.
- Updated documentation to reflect new functionalities and usage examples.

This commit enhances the performance of GEMM operations on the SM100 architecture by leveraging block scaling and optimized memory access patterns.
@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Mar 18, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds end‑to‑end block‑scaled MXFP8 GEMM support for TCGEN5/SM100: new TileLang frontend APIs and intrinsics, TileLang lowering and macro emission for block‑scaled MMA, TCGEN5 builtins and device helpers (UTCCP copy, SF warp‑transpose, sync hooks), CUDA codegen/templates, TMEM copy path, and examples with validation and benchmark. (50 words)

Changes

Cohort / File(s) Summary
Examples
examples/gemm_sm100/gemm_mxfp8_blockscaled.py, examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py
New TileLang/JIT examples implementing MXFP8 block‑scaled GEMM kernels, SF pack/unpack utilities, reference implementations, validation and benchmarking harnesses.
TL Builtins / Meta
src/op/builtin.cc, src/op/builtin.h, src/op/tcgen5_meta.h
Added TL builtins for block‑scaled MMA, CP, SF warp‑transpose and before/after thread‑sync; new GetTCGEN5BlockScaledInstrDesc to encode block‑scaled instruction descriptors (includes SF IDs).
Gemm FFI / Reflection
src/op/gemm_py.cc, src/op/gemm_py.h
Extended GemmPy to carry SFA/SFB regions and SF IDs; added tl.get_tcgen5_blockscaled_instr_desc binding for descriptor construction.
CUDA Codegen / Templates
src/target/codegen_cuda.cc, src/tl_templates/cuda/instruction/tcgen05mma.h, src/tl_templates/cuda/tcgen_05.h
Emit and implement new tcgen05 intrinsics: mma_blockscaled_ss, cp, sf_warp_transpose and thread‑sync hooks; device helpers for SF copy/warp‑transpose and SMEM descriptor builder; codegen mappings and replacers updated.
TileLang Intrinsics & MacroGen
tilelang/intrinsics/tcgen05_macro_generator.py
Added tcgen05mma_blockscaled emission path and helper get_tcgen5_blockscaled_instr_desc; wiring for TMEM SF pointers and per‑invocation SF IDs.
Frontend APIs / Builtins
tilelang/language/gemm_op.py, tilelang/language/builtin.py
Added public APIs: blockscaled_gemm, make_blockscaled_gemm_layout; added builtin wrappers: tcgen05_before_thread_sync, tcgen05_after_thread_sync, tcgen05_cp, sf_warp_transpose, make_sf_smem_desc.
Language Exports & IR / TIR Wrappers
tilelang/language/__init__.py, tilelang/language/ast/ir.py, tilelang/language/tir/ir.py, tilelang/language/tir/op.py
Exported new PTX/TIR wrappers and intrinsics: ptx_tcgen05_mma_blockscaled_ss, ptx_tcgen05_cp, ptx_tcgen05_sf_warp_transpose and corresponding TIR op wrappers and dtype‑forward bindings.
TileOp / Gemm Lowering
tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_base.py, tilelang/tileop/gemm/gemm_tcgen05.py, tilelang/tileop/gemm/gemm_py.cc
Added SF properties/fields and new TCGEN5 block‑scaled lowering path (_lower_blockscaled), enforced 1x1 warp partition for block‑scaled GEMM, and new primfuncs for block‑scaled lowering.
Copy / TMEM Path
src/op/copy.cc, src/op/copy.h
Introduced CopyInst::kTMemCp (shared→tmem copy), CheckTMemCp helper, and lowering that emits per‑chunk ptx_tcgen05_cp calls for UTCCP uploads.
Language Exports Broadening
tilelang/language/__init__.py, tilelang/language/builtin.py
Expanded public exports to include blockscaled_gemm and the new TCGEN05 helper intrinsics and utilities.

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant PyAPI as blockscaled_gemm (Python API)
    participant Emitter as TensorCoreIntrinEmitter
    participant MacroGen as tcgen05_macro_generator
    participant Codegen as CUDA Codegen
    participant Device as GPU Device

    User->>PyAPI: call blockscaled_gemm(A,B,C,SFA,SFB,...)
    PyAPI->>Emitter: configure layouts, TMEM for C, and SF regions
    Emitter->>MacroGen: request tcgen05mma_blockscaled emission (with SF IDs)
    MacroGen->>MacroGen: build blockscaled instr desc, resolve TMEM SF pointers
    MacroGen->>Codegen: emit ptx_tcgen05_mma_blockscaled_ss / ptx_tcgen05_cp / ptx_tcgen05_sf_warp_transpose / sync calls
    Codegen->>Device: generate tcgen05.mma.cta_group / tcgen05.cp / sf_transpose / sync intrinsics
    Device->>Device: copy SFs to TMEM, warp‑transpose SFs, perform block‑scaled MMA, write TMEM→DRAM
    Device-->>User: results written to C
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • chengyupku

"I'm a rabbit in the warp and lane,
I pack scale‑factors in a playful train,
TMEM hums while SMEM twirls the transpose art,
Block‑scaled GEMM hops — precision plays its part,
Hooray for MXFP8 — a furry, speedy start!" 🐇✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main feature: block-scaled GEMM support for MXFP8 on Blackwell, which is the primary change across all modified files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
📝 Coding Plan
  • Generate coding plan for human review comments

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 11

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Around line 35-37: The current calculations use floor integer division which
can produce 0 or drop tail chunks (e.g., k_iters, sf_load_period,
sfa_num_chunks, sfb_num_chunks) causing divide-by-zero or missing scale-factor
data; change these to use ceiling division (or explicitly clamp sf_load_period =
max(1, ceil_div(...))) and compute sfa_num_chunks/sfb_num_chunks with ceil
division so tail chunks are included, and add a short comment documenting the
required assumptions for sf_granularity_k, block_K, block_M, and block_N to make
the behavior explicit (update the code around k_iters, sf_load_period, and the
sfa_num_chunks/sfb_num_chunks calculations; same fix for the analogous block at
lines 61-64).

In `@src/op/tcgen5_meta.h`:
- Around line 218-265: Summary: a_sf_id and b_sf_id are packed into 2-bit fields
and currently can be silently truncated; validate them before packing in
tl.get_tcgen5_blockscaled_instr_desc. Add explicit checks (e.g., ICHECK(a_sf_id
>= 0 && a_sf_id < 4) and ICHECK(b_sf_id >= 0 && b_sf_id < 4)) near the top of
the function that builds the descriptor (the same scope that defines
encode_mxfp_dtype, set_bits, and computes desc) so invalid values fail fast;
then continue to cast to uint32_t when calling
set_bits(static_cast<uint32_t>(a_sf_id), 29, 2) and
set_bits(static_cast<uint32_t>(b_sf_id), 4, 2).

In `@src/target/codegen_cuda.cc`:
- Around line 2541-2543: The replacer currently emits
"tcgen05mma_blockscaled_ws_ss" when enable_ws is true, but that WS helper
doesn't exist yet; update the registration in replacer.register_rule for
"(tcgen05_name)" to avoid emitting the WS variant until implemented—either
always emit "tcgen05mma_blockscaled_ss" regardless of enable_ws, or add an
explicit guard (e.g., if (enable_ws) { /* TODO: skip or fall back */ } else {
replacer.register_rule(... "tcgen05mma_blockscaled_ss"); }) so the code uses
"tcgen05mma_blockscaled_ss" and never references "tcgen05mma_blockscaled_ws_ss"
until the WS helper/header is added.

In `@src/tl_templates/cuda/tcgen_05.h`:
- Around line 72-82: The in-place warp transpose in tcgen05_sf_warp_transpose
performs a pre-store __syncwarp() but lacks a post-store barrier, so
shared-memory writes may be visible partially to a subsequent UTCCP consumer;
add a second __syncwarp() immediately after the store loop (after the lines that
write smem_ptr[lane * 4 + ...]) to ensure all threads have completed their
stores before returning, or alternatively document/require the caller to perform
this post-store synchronization.

In `@tilelang/intrinsics/tcgen05_macro_generator.py`:
- Around line 720-722: The packing of SF IDs into runtime_instr_desc (using
base_instr_desc | (_sf_a << 29) | (_sf_b << 4)) must mask and validate sf_a_id
and sf_b_id: when creating _sf_a/_sf_b (and before shifting) apply a mask like
(& 0x3) so only the bottom 2 bits are OR-ed, and if sf_a_id or sf_b_id are
constant ints reject/raise if they are outside the allowed range 0..3; update
the logic around the tvm.tir.const creation and the variables _sf_a, _sf_b to
both validate constants and mask non-constants prior to shifting to avoid
spilling into neighboring descriptor fields.
- Around line 599-608: The code currently ignores the returned WS and 2CTA flags
from get_tcgen5_mma_meta (meta -> _enable_ws, enable_2cta) and unconditionally
forces enable_ws=0, which can mis-lower shapes that require WS or 2CTA; update
the logic in tcgen05_macro_generator.py where meta is unpacked
(get_tcgen5_mma_meta and atom_m/atom_n/atom_k/_enable_ws/enable_2cta) to
validate that _enable_ws==0 and enable_2cta==0 before proceeding and otherwise
raise a ValueError (with a clear message including M/N/K/a_dtype/accum_dtype) so
the method fails fast for unsupported WS or 2CTA configurations instead of
silently forcing cta_group::1.
- Around line 681-694: The code currently strips BufferRegion operands to their
underlying buffer data and then sets sfa_offset/sfb_offset to 0, causing
subregions to lose their base offsets; update the SFA_tmem/SFB_tmem handling so
that when the operand is a BufferRegion you capture both the buffer data and its
base offset (e.g. set sfa_data = SFA_tmem.buffer.data and sfa_offset =
SFA_tmem.offset) and when it is a Buffer set sfa_data = SFA_tmem.data and
sfa_offset = 0 (do the analogous change for SFB_tmem -> sfb_data and
sfb_offset); ensure later code uses these sfa_offset/sfb_offset values instead
of hard-coding 0.

In `@tilelang/language/builtin.py`:
- Around line 1040-1048: The code in tcgen05_cp drops the TMEM slice offset by
replacing BufferLoad/BufferRegion destinations with only buffer.data; preserve
the destination's region/indices when computing tmem_ptr so the tmem_col_offset
is applied correctly. Update the branches handling BufferLoad and BufferRegion
to compute and incorporate the region index/offset (from the BufferLoad indices
or BufferRegion.region/offset fields) into tmem_ptr (or augment tmem_col_offset)
before calling tir.call_intrin("void", tir.op.Op.get("tl.ptx_tcgen05_utccp"),
smem_ptr, tmem_ptr, tmem_col_offset) so sliced destinations write to the correct
TMEM column. Ensure you reference BufferLoad, BufferRegion, tmem_dst, tmem_ptr
and tmem_col_offset when making the change.

In `@tilelang/language/gemm_op.py`:
- Around line 295-302: Before constructing the TensorCoreIntrinEmitter, validate
that A and B FP8 formats match: read a_dtype = str(A_region.buffer.dtype) and
b_dtype = str(B_region.buffer.dtype) and if they are different FP8 encodings
(e.g., E4M3 vs E5M2) raise/return a clear error instead of proceeding; ensure
you either pass the real b_dtype into TensorCoreIntrinEmitter (b_dtype=b_dtype)
or abort when mixed FP8 formats are detected. This change touches the emitter
construction site (TensorCoreIntrinEmitter) and the code paths that read
A_region.buffer.dtype and B_region.buffer.dtype so the block-scaled descriptor
is not silently built with a single dtype for both operands.
- Around line 327-331: The return annotation "Layout" in
make_blockscaled_gemm_layout is not imported and triggers Pyflakes F821; add an
import for Layout (e.g., from tilelang.layout import Layout) at module level or,
if you want to avoid runtime imports, wrap it in a TYPE_CHECKING block (from
typing import TYPE_CHECKING; if TYPE_CHECKING: from tilelang.layout import
Layout) so the symbol exists for linting while preserving runtime behavior.

In `@tilelang/language/tir/op.py`:
- Around line 1309-1323: The ptx_tcgen05_utccp intrinsic currently names its
first parameter smem_desc but lowering expects a raw shared-memory pointer;
rename the parameter and its docstring from smem_desc to smem_ptr in the
ptx_tcgen05_utccp function (and update any callers/docs) so the frontend
contract matches the lowering, or alternatively add a distinct intrinsic that
explicitly accepts an smem descriptor if you need both forms; ensure the
call_intrin invocation still passes the pointer value and update the docstring
to describe smem_ptr as a raw shared-memory pointer.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c73147b3-3b03-42ed-8ea2-9e21ff0ce44b

📥 Commits

Reviewing files that changed from the base of the PR and between f8dc61c and a1bb0ee.

📒 Files selected for processing (15)
  • examples/gemm_sm100/gemm_mxfp8_blockscaled.py
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/gemm_py.cc
  • src/op/tcgen5_meta.h
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/instruction/tcgen05mma.h
  • src/tl_templates/cuda/tcgen_05.h
  • tilelang/intrinsics/tcgen05_macro_generator.py
  • tilelang/language/__init__.py
  • tilelang/language/ast/ir.py
  • tilelang/language/builtin.py
  • tilelang/language/gemm_op.py
  • tilelang/language/tir/ir.py
  • tilelang/language/tir/op.py

Comment on lines +35 to +37
k_iters = T.ceildiv(K, block_K)
# 4 packed E8M0 per uint32 → load every 4 stages
sf_load_period = sf_granularity_k * 4 // block_K
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

These floor divisions break the example outside the current happy path.

sf_load_period silently floors and can become 0, and sfa_num_chunks/sfb_num_chunks truncate tail chunks. Changing block_K, block_M, or block_N away from the current multiples will either divide by zero or skip scale-factor data entirely.

At least make the assumptions explicit
         k_iters = T.ceildiv(K, block_K)
         # 4 packed E8M0 per uint32 → load every 4 stages
         sf_load_period = sf_granularity_k * 4 // block_K
+        assert sf_load_period > 0 and (sf_granularity_k * 4) % block_K == 0, (
+            "block_K must evenly divide 4 * sf_granularity_k"
+        )
@@
-        sfa_num_chunks = block_M // 128  # number of 128-element UTCCP chunks
-        sfb_num_chunks = block_N // 128
+        assert block_M % 128 == 0 and block_N % 128 == 0, (
+            "This example currently requires block_M and block_N to be multiples of 128"
+        )
+        sfa_num_chunks = block_M // 128  # number of 128-element UTCCP chunks
+        sfb_num_chunks = block_N // 128

Also applies to: 61-64

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 35 - 37, The
current calculations use floor integer division which can produce 0 or drop tail
chunks (e.g., k_iters, sf_load_period, sfa_num_chunks, sfb_num_chunks) causing
divide-by-zero or missing scale-factor data; change these to use ceiling
division (or explicitly clamp sf_load_period = max(1, ceil_div(...))) and
compute sfa_num_chunks/sfb_num_chunks with ceil division so tail chunks are
included, and add a short comment documenting the required assumptions for
sf_granularity_k, block_K, block_M, and block_N to make the behavior explicit
(update the code around k_iters, sf_load_period, and the
sfa_num_chunks/sfb_num_chunks calculations; same fix for the analogous block at
lines 61-64).

Comment thread src/op/tcgen5_meta.h
Comment on lines +218 to +265
int a_sf_id, int b_sf_id) {
ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16";
ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8";
ICHECK(scale_in_a == 1 || scale_in_a == -1);
ICHECK(scale_in_b == 1 || scale_in_b == -1);

// a_format / b_format for MXF8F6F4: E4M3=0, E5M2=1
auto encode_mxfp_dtype = [&](DataType dtype) -> uint32_t {
if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() ||
dtype.is_float8_e4m3()) {
return 0u; // E4M3
} else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) {
return 1u; // E5M2
}
LOG(FATAL) << "Unsupported dtype for block-scaled descriptor: " << dtype;
return 0u;
};

auto set_bits = [](uint32_t value, int start, int width) -> uint32_t {
uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1);
return (value & mask) << start;
};

uint32_t a_format = encode_mxfp_dtype(ab_dtype);
uint32_t b_format = a_format;
uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u;
uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u;
uint32_t a_major = a_is_k_major ? 0u : 1u;
uint32_t b_major = b_is_k_major ? 0u : 1u;
uint32_t n_dim = static_cast<uint32_t>(atom_n >> 3);
uint32_t m_dim = static_cast<uint32_t>(atom_m >> 4);

uint32_t desc = 0;
desc |= set_bits(0, 0, 2); // sparse_id2
desc |= set_bits(0, 2, 1); // sparse_flag
// bit 3 reserved
desc |= set_bits(static_cast<uint32_t>(b_sf_id), 4, 2); // b_sf_id
// bit 6 reserved
desc |= set_bits(a_format, 7, 3); // a_format
desc |= set_bits(b_format, 10, 3); // b_format
desc |= set_bits(a_neg, 13, 1); // a_negate
desc |= set_bits(b_neg, 14, 1); // b_negate
desc |= set_bits(a_major, 15, 1); // a_major
desc |= set_bits(b_major, 16, 1); // b_major
desc |= set_bits(n_dim, 17, 6); // n_dim
desc |= set_bits(1, 23, 1); // scale_format = 1 (E8M0)
desc |= set_bits(m_dim, 24, 5); // m_dim
desc |= set_bits(static_cast<uint32_t>(a_sf_id), 29, 2); // a_sf_id
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate a_sf_id and b_sf_id before packing them.

Lines 254-265 write both SF IDs through a 2-bit field, so invalid values are silently truncated. For example, -1 becomes 3, and 4 aliases back to 0. Because this helper is exposed via tl.get_tcgen5_blockscaled_instr_desc, bad frontend inputs currently compile into the wrong descriptor instead of failing fast.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/tcgen5_meta.h` around lines 218 - 265, Summary: a_sf_id and b_sf_id
are packed into 2-bit fields and currently can be silently truncated; validate
them before packing in tl.get_tcgen5_blockscaled_instr_desc. Add explicit checks
(e.g., ICHECK(a_sf_id >= 0 && a_sf_id < 4) and ICHECK(b_sf_id >= 0 && b_sf_id <
4)) near the top of the function that builds the descriptor (the same scope that
defines encode_mxfp_dtype, set_bits, and computes desc) so invalid values fail
fast; then continue to cast to uint32_t when calling
set_bits(static_cast<uint32_t>(a_sf_id), 29, 2) and
set_bits(static_cast<uint32_t>(b_sf_id), 4, 2).

Comment on lines +2541 to +2543
replacer.register_rule(
"(tcgen05_name)",
enable_ws ? "tcgen05mma_blockscaled_ws_ss" : "tcgen05mma_blockscaled_ss");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't emit the WS block-scaled helper until it exists.

This branch can generate tl::tcgen05mma_blockscaled_ws_ss, but the new instruction header in this PR only defines tcgen05mma_blockscaled_ss. Any block-scaled kernel with enable_ws=true will fail during CUDA compilation.

Temporary guard until the WS variant is implemented
-    replacer.register_rule(
-        "(tcgen05_name)",
-        enable_ws ? "tcgen05mma_blockscaled_ws_ss" : "tcgen05mma_blockscaled_ss");
+    ICHECK(!enable_ws)
+        << "WS block-scaled tcgen05mma emission is not implemented yet.";
+    replacer.register_rule("(tcgen05_name)", "tcgen05mma_blockscaled_ss");
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
replacer.register_rule(
"(tcgen05_name)",
enable_ws ? "tcgen05mma_blockscaled_ws_ss" : "tcgen05mma_blockscaled_ss");
ICHECK(!enable_ws)
<< "WS block-scaled tcgen05mma emission is not implemented yet.";
replacer.register_rule("(tcgen05_name)", "tcgen05mma_blockscaled_ss");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_cuda.cc` around lines 2541 - 2543, The replacer currently
emits "tcgen05mma_blockscaled_ws_ss" when enable_ws is true, but that WS helper
doesn't exist yet; update the registration in replacer.register_rule for
"(tcgen05_name)" to avoid emitting the WS variant until implemented—either
always emit "tcgen05mma_blockscaled_ss" regardless of enable_ws, or add an
explicit guard (e.g., if (enable_ws) { /* TODO: skip or fall back */ } else {
replacer.register_rule(... "tcgen05mma_blockscaled_ss"); }) so the code uses
"tcgen05mma_blockscaled_ss" and never references "tcgen05mma_blockscaled_ws_ss"
until the WS helper/header is added.

Comment on lines +72 to +82
TL_DEVICE void tcgen05_sf_warp_transpose(uint32_t *smem_ptr) {
const uint32_t lane = threadIdx.x % 32;
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++i)
values[i] = smem_ptr[(i ^ (lane >> 3)) * 32 + lane];
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++i)
smem_ptr[lane * 4 + (i ^ (lane >> 3))] = values[i];
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add a post-store warp sync to the in-place transpose.

Line 78 synchronizes before the writes, which protects the read phase, but nothing orders the shared-memory stores after Line 81 before the next consumer runs. If the elected lane issues UTCCP immediately after this helper returns, it can observe partially written transposed data. Add a second __syncwarp() after the store loop, or make the caller perform one explicitly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/tcgen_05.h` around lines 72 - 82, The in-place warp
transpose in tcgen05_sf_warp_transpose performs a pre-store __syncwarp() but
lacks a post-store barrier, so shared-memory writes may be visible partially to
a subsequent UTCCP consumer; add a second __syncwarp() immediately after the
store loop (after the lines that write smem_ptr[lane * 4 + ...]) to ensure all
threads have completed their stores before returning, or alternatively
document/require the caller to perform this post-store synchronization.

Comment on lines +599 to +608
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 5:
raise ValueError(
f"Unsupported TCGEN5MMA configuration for block-scaled: M={m_dim}, N={n_dim}, "
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
)
atom_m, atom_n, atom_k, _enable_ws, enable_2cta = (int(x) for x in meta)
# Block-scaled MMA (.block_scale) does NOT support .ws variant per PTX ISA.
# Force non-ws mode. atom_m=128 is required for cta_group::1 non-ws.
enable_ws = 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Reject WS / 2CTA metadata before forcing the non-WS cta_group::1 form.

Lines 605-608 drop _enable_ws and enable_2cta and still emit the only variant this method knows how to generate. That will mis-lower any shape whose selected TCGEN5 meta requires WS or 2CTA support. Fail fast unless the returned meta already matches the supported non-WS cta_group::1 path.

🧰 Tools
🪛 Ruff (0.15.6)

[warning] 605-605: Unpacked variable atom_k is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


[warning] 605-605: Unpacked variable enable_2cta is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 599 - 608, The
code currently ignores the returned WS and 2CTA flags from get_tcgen5_mma_meta
(meta -> _enable_ws, enable_2cta) and unconditionally forces enable_ws=0, which
can mis-lower shapes that require WS or 2CTA; update the logic in
tcgen05_macro_generator.py where meta is unpacked (get_tcgen5_mma_meta and
atom_m/atom_n/atom_k/_enable_ws/enable_2cta) to validate that _enable_ws==0 and
enable_2cta==0 before proceeding and otherwise raise a ValueError (with a clear
message including M/N/K/a_dtype/accum_dtype) so the method fails fast for
unsupported WS or 2CTA configurations instead of silently forcing cta_group::1.

Comment on lines +720 to +722
_sf_a = tvm.tir.const(sf_a_id, "int32") if isinstance(sf_a_id, int) else sf_a_id
_sf_b = tvm.tir.const(sf_b_id, "int32") if isinstance(sf_b_id, int) else sf_b_id
runtime_instr_desc = base_instr_desc | (_sf_a << 29) | (_sf_b << 4)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Mask or validate runtime SF IDs before OR-ing them into the descriptor.

Line 722 assumes sf_a_id and sf_b_id are already 2-bit values. If either one is larger than 3, the shift spills into neighboring descriptor fields instead of only updating the SF slots. Apply & 0x3 when packing, and still reject constant IDs outside [0, 3].

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 720 - 722, The
packing of SF IDs into runtime_instr_desc (using base_instr_desc | (_sf_a << 29)
| (_sf_b << 4)) must mask and validate sf_a_id and sf_b_id: when creating
_sf_a/_sf_b (and before shifting) apply a mask like (& 0x3) so only the bottom 2
bits are OR-ed, and if sf_a_id or sf_b_id are constant ints reject/raise if they
are outside the allowed range 0..3; update the logic around the tvm.tir.const
creation and the variables _sf_a, _sf_b to both validate constants and mask
non-constants prior to shifting to avoid spilling into neighboring descriptor
fields.

Comment thread tilelang/language/builtin.py Outdated
Comment on lines +1040 to +1048
if isinstance(tmem_dst, (tir.Buffer,)):
tmem_ptr = tmem_dst.data
elif isinstance(tmem_dst, BufferLoad):
tmem_ptr = tmem_dst.buffer.data
elif isinstance(tmem_dst, BufferRegion):
tmem_ptr = tmem_dst.buffer.data
else:
tmem_ptr = tmem_dst
return tir.call_intrin("void", tir.op.Op.get("tl.ptx_tcgen05_utccp"), smem_ptr, tmem_ptr, tmem_col_offset)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't discard the TMEM slice offset in tcgen05_cp.

For BufferLoad/BufferRegion destinations this helper drops the region indices and forwards only buffer.data. The codegen path then applies only tmem_col_offset, so a sliced destination still writes from TMEM column 0 unless the caller manually duplicates the offset.

Safe short-term fix
-    if isinstance(tmem_dst, (tir.Buffer,)):
+    if isinstance(tmem_dst, tir.Buffer):
         tmem_ptr = tmem_dst.data
-    elif isinstance(tmem_dst, BufferLoad):
-        tmem_ptr = tmem_dst.buffer.data
-    elif isinstance(tmem_dst, BufferRegion):
-        tmem_ptr = tmem_dst.buffer.data
+    elif isinstance(tmem_dst, (BufferLoad, BufferRegion)):
+        raise ValueError(
+            "tcgen05_cp currently expects the base TMEM buffer; "
+            "use tmem_col_offset to select destination columns."
+        )
     else:
         tmem_ptr = tmem_dst
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/builtin.py` around lines 1040 - 1048, The code in
tcgen05_cp drops the TMEM slice offset by replacing BufferLoad/BufferRegion
destinations with only buffer.data; preserve the destination's region/indices
when computing tmem_ptr so the tmem_col_offset is applied correctly. Update the
branches handling BufferLoad and BufferRegion to compute and incorporate the
region index/offset (from the BufferLoad indices or BufferRegion.region/offset
fields) into tmem_ptr (or augment tmem_col_offset) before calling
tir.call_intrin("void", tir.op.Op.get("tl.ptx_tcgen05_utccp"), smem_ptr,
tmem_ptr, tmem_col_offset) so sliced destinations write to the correct TMEM
column. Ensure you reference BufferLoad, BufferRegion, tmem_dst, tmem_ptr and
tmem_col_offset when making the change.

Comment thread tilelang/language/gemm_op.py Outdated
Comment on lines +295 to +302
a_dtype = str(A_region.buffer.dtype)
accum_dtype = str(C_region.buffer.dtype)

# Create intrinsic emitter — for block-scaled, TCGEN5 always uses 1 warp group
emitter = TensorCoreIntrinEmitter(
a_dtype=a_dtype,
b_dtype=a_dtype,
accum_dtype=accum_dtype,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate A/B dtype compatibility before building the emitter.

This hard-codes b_dtype=a_dtype, so a mixed E4M3/E5M2 call will silently encode B as A's format. The block-scaled descriptor path currently mirrors a single FP8 dtype into both a_format and b_format, so this should be rejected explicitly instead of generating the wrong descriptor.

Minimal defensive fix
     a_dtype = str(A_region.buffer.dtype)
+    b_dtype = str(B_region.buffer.dtype)
     accum_dtype = str(C_region.buffer.dtype)
+    if b_dtype != a_dtype:
+        raise ValueError(
+            "blockscaled_gemm currently requires A and B to use the same FP8 dtype."
+        )
 
     # Create intrinsic emitter — for block-scaled, TCGEN5 always uses 1 warp group
     emitter = TensorCoreIntrinEmitter(
         a_dtype=a_dtype,
-        b_dtype=a_dtype,
+        b_dtype=b_dtype,
         accum_dtype=accum_dtype,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/gemm_op.py` around lines 295 - 302, Before constructing the
TensorCoreIntrinEmitter, validate that A and B FP8 formats match: read a_dtype =
str(A_region.buffer.dtype) and b_dtype = str(B_region.buffer.dtype) and if they
are different FP8 encodings (e.g., E4M3 vs E5M2) raise/return a clear error
instead of proceeding; ensure you either pass the real b_dtype into
TensorCoreIntrinEmitter (b_dtype=b_dtype) or abort when mixed FP8 formats are
detected. This change touches the emitter construction site
(TensorCoreIntrinEmitter) and the code paths that read A_region.buffer.dtype and
B_region.buffer.dtype so the block-scaled descriptor is not silently built with
a single dtype for both operands.

Comment on lines +327 to +331
def make_blockscaled_gemm_layout(
C: BufferLikeType,
A: BufferLikeType,
transpose_A: bool = False,
) -> "Layout":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file structure and locate the relevant lines
fd gemm_op.py

Repository: tile-ai/tilelang

Length of output: 88


🏁 Script executed:

# Check the imports and the function in question
cat -n tilelang/language/gemm_op.py | head -40

Repository: tile-ai/tilelang

Length of output: 1407


🏁 Script executed:

# Look at the lines around 327-331
sed -n '320,340p' tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 810


🏁 Script executed:

# Search for all imports related to Layout in the file
rg "import.*Layout|from.*Layout" tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Also check if Layout is imported anywhere or if it's being used as a string
rg "Layout" tilelang/language/gemm_op.py | head -20

Repository: tile-ai/tilelang

Length of output: 120


🏁 Script executed:

# Find where Layout is defined
fd -t f "\.py$" | xargs rg "^class Layout" | head -5

Repository: tile-ai/tilelang

Length of output: 105


🏁 Script executed:

# Check what's in tilelang.layout module
fd layout.py

Repository: tile-ai/tilelang

Length of output: 381


🏁 Script executed:

# Check if there's a TYPE_CHECKING import block in the file
rg "TYPE_CHECKING|tilelang.layout" tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 122


🏁 Script executed:

# Also check the full imports section to be thorough
sed -n '1,30p' tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 851


🏁 Script executed:

# Search for the full tilelang.layout import line
rg -B2 -A2 "from tilelang.layout import" tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 185


🏁 Script executed:

# Check if Layout is imported or if we need to verify the entire imports section
cat -n tilelang/language/gemm_op.py | sed -n '1,100p'

Repository: tile-ai/tilelang

Length of output: 4147


🏁 Script executed:

# Look for TYPE_CHECKING imports more carefully
sed -n '1,500p' tilelang/language/gemm_op.py | grep -B5 -A5 "TYPE_CHECKING\|tilelang.layout"

Repository: tile-ai/tilelang

Length of output: 429


🏁 Script executed:

# Check around the context where the tilelang.layout import appears
sed -n '180,220p' tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 1418


🏁 Script executed:

# Let's see the function around line 327-331 more completely
sed -n '325,345p' tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 758


🏁 Script executed:

# Check if there's a pyproject.toml or .flake8 to understand the Ruff config
fd -t f "pyproject.toml|ruff.toml|.flake8"

Repository: tile-ai/tilelang

Length of output: 84


🏁 Script executed:

# Check the ruff configuration
cat ruff.toml

Repository: tile-ai/tilelang

Length of output: 1137


🏁 Script executed:

# Also check pyproject.toml for any ruff config
cat pyproject.toml | grep -A 50 "tool.ruff"

Repository: tile-ai/tilelang

Length of output: 1616


🏁 Script executed:

# Let's verify if Layout is actually being used elsewhere in the file to understand the context better
rg "Layout" tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 120


🏁 Script executed:

# Let's check what the function actually returns to understand the full context
sed -n '327,380p' tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 1534


🏁 Script executed:

# Verify if Layout is referenced anywhere else in the function body
sed -n '325,380p' tilelang/language/gemm_op.py | grep -i "layout"

Repository: tile-ai/tilelang

Length of output: 368


🏁 Script executed:

# Let's check what make_mma_store_layout actually returns
fd -t f "\.py$" -path "*tcgen05*" | head -5

Repository: tile-ai/tilelang

Length of output: 228


🏁 Script executed:

# Check the TensorCoreIntrinEmitter class to see what make_mma_store_layout returns
rg "def make_mma_store_layout" -A 5

Repository: tile-ai/tilelang

Length of output: 2844


🏁 Script executed:

# Check imports in tcgen05_macro_generator.py
head -50 tilelang/intrinsics/tcgen05_macro_generator.py | grep -E "^import|^from"

Repository: tile-ai/tilelang

Length of output: 487


🏁 Script executed:

# Check if Layout is imported there
rg "Layout" tilelang/intrinsics/tcgen05_macro_generator.py | head -10

Repository: tile-ai/tilelang

Length of output: 576


🏁 Script executed:

# Check if TYPE_CHECKING is used anywhere in the file
rg "TYPE_CHECKING" tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Final verification: check the exact Ruff behavior with F821 and future annotations
rg "from __future__ import annotations" tilelang/language/gemm_op.py

Repository: tile-ai/tilelang

Length of output: 94


🏁 Script executed:

# Let's do a final check to see how Layout is used across the codebase to understand the import pattern
rg "from tilelang.layout import Layout" | head -10

Repository: tile-ai/tilelang

Length of output: 474


Import Layout or use a TYPE_CHECKING import to avoid the F821 linting error.

The return type annotation references Layout which is not imported in this module. Although the string annotation from from __future__ import annotations prevents runtime errors, Ruff's Pyflakes (F821) check will flag the undefined name. You can fix this by adding from tilelang.layout import Layout at the module level or using a TYPE_CHECKING block if the import is only needed for the annotation.

🧰 Tools
🪛 Ruff (0.15.6)

[error] 331-331: Undefined name Layout

(F821)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/gemm_op.py` around lines 327 - 331, The return annotation
"Layout" in make_blockscaled_gemm_layout is not imported and triggers Pyflakes
F821; add an import for Layout (e.g., from tilelang.layout import Layout) at
module level or, if you want to avoid runtime imports, wrap it in a
TYPE_CHECKING block (from typing import TYPE_CHECKING; if TYPE_CHECKING: from
tilelang.layout import Layout) so the symbol exists for linting while preserving
runtime behavior.

Comment thread tilelang/language/tir/op.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (3)
tilelang/language/gemm_op.py (2)

330-334: ⚠️ Potential issue | 🟡 Minor

Import Layout for the return annotation.

Line 334 still references Layout without bringing it into the module namespace, so Ruff keeps reporting F821 here. A TYPE_CHECKING import is enough if you want to avoid a runtime import.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/gemm_op.py` around lines 330 - 334, The function
make_blockscaled_gemm_layout uses the return annotation "Layout" but Layout is
not imported, causing F821; add an import for Layout into this module
(preferably inside an if TYPE_CHECKING: block) so the name is available for type
checking without a runtime dependency; reference the Layout symbol and the
make_blockscaled_gemm_layout function when adding the import.

295-302: ⚠️ Potential issue | 🟠 Major

Reject mixed FP8 operand formats here instead of mirroring A onto B.

b_dtype=a_dtype still silently encodes B as whatever A uses. A mixed E4M3/E5M2 call will build the wrong block-scaled descriptor instead of failing fast.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/gemm_op.py` around lines 295 - 302, Currently the code sets
b_dtype implicitly to a_dtype when creating TensorCoreIntrinEmitter; instead,
read B's actual dtype (use B_region.buffer.dtype) into b_dtype and add a
validation: if either a_dtype or b_dtype is an FP8 format (e.g., "e4m3" or
"e5m2") and a_dtype != b_dtype, raise an explicit error (ValueError/Assertion)
rejecting mixed FP8 operand formats; keep
TensorCoreIntrinEmitter(a_dtype=a_dtype, b_dtype=b_dtype,
accum_dtype=accum_dtype) after this check so mixed FP8s fail fast rather than
silently mirroring A onto B.
examples/gemm_sm100/gemm_mxfp8_blockscaled.py (1)

36-38: ⚠️ Potential issue | 🟠 Major

Make the example’s tile-size assumptions explicit.

sf_load_period still depends on floor division and can collapse to 0 or skip the last partial pack, while sfa_num_chunks / sfb_num_chunks still drop any tail below 128 elements. As written, this schedule only holds when block_K evenly divides 4 * sf_granularity_k and block_M / block_N are multiples of 128.

Also applies to: 62-65

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 36 - 38, The
schedule assumes tile sizes divide certain pack/granularity values but currently
uses floor division and drops tails; fix by making assumptions explicit and
handling tails: compute sf_load_period = T.ceildiv(sf_granularity_k * 4,
block_K) (not floor) and ensure it is at least 1, compute k_iters = T.ceildiv(K,
block_K) (already present) and use ceildiv for sfa_num_chunks/sfb_num_chunks
(e.g., T.ceildiv(block_M, 128) and T.ceildiv(block_N, 128)) so partial
128-element tails are included, or add explicit assertions that block_K evenly
divides 4*sf_granularity_k and block_M/block_N are multiples of 128 (raise
helpful error messages referencing sf_load_period, sf_granularity_k, block_K,
sfa_num_chunks, sfb_num_chunks, block_M, block_N) if you prefer to keep the
simpler schedule.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Around line 142-143: sf_a_id and sf_b_id are computed from sf_load_period but
must be derived from the scale-factor granularity (sf_granularity_k) so SF IDs
stay within 0..3 when a scale factor spans multiple block_K tiles; change the
calculation of sf_a_id and sf_b_id to use the granularity-based index (e.g.,
floor(k * block_K / sf_granularity_k) mod sf_load_period or equivalent) so each
packed E8M0 scale factor is reused across the correct number of K tiles; update
the assignments near where sf_a_id and sf_b_id are set and ensure the logic
aligns with tcgen05 expectations and uses sf_granularity_k instead of
sf_load_period directly.

In `@tilelang/language/gemm_op.py`:
- Around line 240-241: Validate that the scalar-field IDs are within the
documented 0–3 range before packing them into the two-bit fields: add explicit
checks for sf_a_id and sf_b_id (the places where the tcgen05 / block-scaled
descriptor is built) and raise a clear exception (e.g., ValueError) if either is
outside 0..3; apply the same guarded checks at the other occurrence referenced
(the second place handling sf_a_id/sf_b_id around the later block) so invalid
constants fail fast instead of producing an incorrect descriptor.
- Around line 284-293: Validate that A_shape, B_shape and C_shape are consistent
before creating the TensorCoreIntrinEmitter: compute M = int(C_shape[0]), N =
int(C_shape[1]) and K_expected = int(A_shape[-2] if transpose_A else
A_shape[-1]); also compute K_from_B = int(B_shape[-1] if (not transpose_B) else
B_shape[-2]) (or the symmetric form used in this codepath) and verify K_expected
== K_from_B, and that A provides M rows and B provides N columns consistent with
C (e.g. A_shape[0 or -2] matches M depending on transpose_A, and B_shape[1 or
-1] matches N depending on transpose_B); if any check fails, raise an
informative exception (or assert) before instantiating TensorCoreIntrinEmitter
so the emitter is configured only with matching tile dimensions.

---

Duplicate comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Around line 36-38: The schedule assumes tile sizes divide certain
pack/granularity values but currently uses floor division and drops tails; fix
by making assumptions explicit and handling tails: compute sf_load_period =
T.ceildiv(sf_granularity_k * 4, block_K) (not floor) and ensure it is at least
1, compute k_iters = T.ceildiv(K, block_K) (already present) and use ceildiv for
sfa_num_chunks/sfb_num_chunks (e.g., T.ceildiv(block_M, 128) and
T.ceildiv(block_N, 128)) so partial 128-element tails are included, or add
explicit assertions that block_K evenly divides 4*sf_granularity_k and
block_M/block_N are multiples of 128 (raise helpful error messages referencing
sf_load_period, sf_granularity_k, block_K, sfa_num_chunks, sfb_num_chunks,
block_M, block_N) if you prefer to keep the simpler schedule.

In `@tilelang/language/gemm_op.py`:
- Around line 330-334: The function make_blockscaled_gemm_layout uses the return
annotation "Layout" but Layout is not imported, causing F821; add an import for
Layout into this module (preferably inside an if TYPE_CHECKING: block) so the
name is available for type checking without a runtime dependency; reference the
Layout symbol and the make_blockscaled_gemm_layout function when adding the
import.
- Around line 295-302: Currently the code sets b_dtype implicitly to a_dtype
when creating TensorCoreIntrinEmitter; instead, read B's actual dtype (use
B_region.buffer.dtype) into b_dtype and add a validation: if either a_dtype or
b_dtype is an FP8 format (e.g., "e4m3" or "e5m2") and a_dtype != b_dtype, raise
an explicit error (ValueError/Assertion) rejecting mixed FP8 operand formats;
keep TensorCoreIntrinEmitter(a_dtype=a_dtype, b_dtype=b_dtype,
accum_dtype=accum_dtype) after this check so mixed FP8s fail fast rather than
silently mirroring A onto B.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b3f2d616-e7be-4c47-8e70-e3fcfa9cfc4f

📥 Commits

Reviewing files that changed from the base of the PR and between a1bb0ee and 9966ba5.

📒 Files selected for processing (2)
  • examples/gemm_sm100/gemm_mxfp8_blockscaled.py
  • tilelang/language/gemm_op.py

Comment on lines +142 to +143
sf_a_id=k % sf_load_period,
sf_b_id=k % sf_load_period,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Derive sf_*_id from scale-factor granularity, not the pack reload period.

The tcgen05 block-scaled path only accepts SF IDs 0-3, but k % sf_load_period only works when one block_K tile maps to one scale factor. If block_K < sf_granularity_k, this starts emitting 0..7, 0..15, etc. instead of reusing each packed E8M0 value for multiple K tiles.

One safe way to keep this parameterized
-    sf_load_period = sf_granularity_k * 4 // block_K
+    assert sf_granularity_k % block_K == 0, (
+        "This example currently requires block_K to evenly divide sf_granularity_k"
+    )
+    sf_tiles_per_value = sf_granularity_k // block_K
+    sf_load_period = sf_tiles_per_value * 4
@@
-                    sf_a_id=k % sf_load_period,
-                    sf_b_id=k % sf_load_period,
+                    sf_a_id=(k // sf_tiles_per_value) % 4,
+                    sf_b_id=(k // sf_tiles_per_value) % 4,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 142 - 143,
sf_a_id and sf_b_id are computed from sf_load_period but must be derived from
the scale-factor granularity (sf_granularity_k) so SF IDs stay within 0..3 when
a scale factor spans multiple block_K tiles; change the calculation of sf_a_id
and sf_b_id to use the granularity-based index (e.g., floor(k * block_K /
sf_granularity_k) mod sf_load_period or equivalent) so each packed E8M0 scale
factor is reused across the correct number of K tiles; update the assignments
near where sf_a_id and sf_b_id are set and ensure the logic aligns with tcgen05
expectations and uses sf_granularity_k instead of sf_load_period directly.

Comment on lines +240 to +241
sf_a_id: int = 0,
sf_b_id: int = 0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate constant sf_a_id / sf_b_id before passing them through.

These IDs are documented as 0-3, and the tcgen05 block-scaled descriptor packs them into two-bit fields. A constant value outside that range will silently produce the wrong instruction descriptor here instead of failing fast.

Small guard for the public API
     SFA_tmem = legalize(SFA_tmem)
     SFB_tmem = legalize(SFB_tmem)
     mbar = legalize(mbar) if mbar is not None else None
+    if isinstance(sf_a_id, int) and not 0 <= sf_a_id <= 3:
+        raise ValueError(f"sf_a_id must be in [0, 3], got {sf_a_id}")
+    if isinstance(sf_b_id, int) and not 0 <= sf_b_id <= 3:
+        raise ValueError(f"sf_b_id must be in [0, 3], got {sf_b_id}")

Also applies to: 324-326

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/gemm_op.py` around lines 240 - 241, Validate that the
scalar-field IDs are within the documented 0–3 range before packing them into
the two-bit fields: add explicit checks for sf_a_id and sf_b_id (the places
where the tcgen05 / block-scaled descriptor is built) and raise a clear
exception (e.g., ValueError) if either is outside 0..3; apply the same guarded
checks at the other occurrence referenced (the second place handling
sf_a_id/sf_b_id around the later block) so invalid constants fail fast instead
of producing an incorrect descriptor.

Comment thread tilelang/language/gemm_op.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (5)
src/op/builtin.h (1)

352-365: Minor formatting inconsistency: extra leading space on line 365.

Line 365 has an extra leading space before TVM_DLL compared to lines 355 and 360. This doesn't affect functionality but breaks consistency with surrounding declarations.

-/*!
- * \brief tvm intrinsic for scale factor warp transpose in shared memory.
- */
- TVM_DLL const Op &ptx_tcgen05_sf_warp_transpose();
+/*!
+ * \brief tvm intrinsic for scale factor warp transpose in shared memory.
+ */
+TVM_DLL const Op &ptx_tcgen05_sf_warp_transpose();
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/builtin.h` around lines 352 - 365, Remove the extra leading space
before the declaration of the ptx_tcgen05_sf_warp_transpose intrinsic so it
aligns with the other TVM_DLL declarations (make the line start with "TVM_DLL
const Op &ptx_tcgen05_sf_warp_transpose();"); this fixes the formatting
inconsistency in the block containing ptx_tcgen05_mma_blockscaled_ss and
ptx_tcgen05_cp.
src/op/builtin.cc (1)

212-226: Minor formatting inconsistencies.

The registrations are functionally correct, but there are minor formatting issues:

  • Line 225: Extra space before Integer compared to other registrations.
  • Line 226: Trailing whitespace.
 TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_sf_warp_transpose)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
-                                Integer(CallEffectKind::kOpaque));
-                               
+                               Integer(CallEffectKind::kOpaque));
+
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/builtin.cc` around lines 212 - 226, The three TIR builtin
registrations (ptx_tcgen05_mma_blockscaled_ss, ptx_tcgen05_cp,
ptx_tcgen05_sf_warp_transpose) have minor formatting inconsistencies: remove the
extra space before Integer in the .set_attr call for
ptx_tcgen05_sf_warp_transpose so it matches the other registrations and delete
the trailing whitespace at the end of that line; ensure the
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
call formatting is consistent across all three builtins.
src/op/gemm_py.cc (1)

90-102: Partial scale-factor specification may cause subtle issues.

The parsing allows specifying only some of the block-scaled fields (e.g., sfaRegion_ without sfbRegion_). Since is_blockscaled in gemm_base.py checks that both regions are present, partial specification silently falls back to non-blockscaled mode rather than raising an error.

Consider either:

  1. Requiring all four fields together when any is present, or
  2. Documenting that partial specification is valid and falls back to standard GEMM.
   // Block-scaled GEMM: optional SFA, SFB regions and scale factor IDs
   if (args.size() > 19) {
     node->sfaRegion_ = NormalizeToBufferRegion(args[19]);
   }
   if (args.size() > 20) {
     node->sfbRegion_ = NormalizeToBufferRegion(args[20]);
   }
+  // Validate: either both SF regions are set, or neither
+  if (node->sfaRegion_.defined() != node->sfbRegion_.defined()) {
+    ICHECK(false) << "Block-scaled GEMM requires both sfaRegion and sfbRegion";
+  }
   if (args.size() > 21) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm_py.cc` around lines 90 - 102, The parser in gemm_py.cc allows
specifying sfaRegion_, sfbRegion_, sfAId_, sfBId_ independently which can
produce a partial block-scaled specification that gemm_base.py's is_blockscaled
(which expects both regions) will treat as not blockscaled; update the parsing
logic so that when any of the four related args (sfaRegion_, sfbRegion_, sfAId_,
sfBId_) is present you require all four — e.g., after reading args[19..22]
validate presence of the other fields and raise an error or throw a
runtime/parse exception if any are missing; locate the assignments to
sfaRegion_, sfbRegion_, sfAId_, sfBId_ and the use of NormalizeToBufferRegion to
implement this all-or-none validation, referencing is_blockscaled in
gemm_base.py as the behavioral expectation.
src/op/gemm_py.h (1)

39-41: Consider consistent naming: sfaRegion_ vs sfAId_.

The naming uses different casing patterns:

  • sfaRegion_, sfbRegion_ (all lowercase "sfa"/"sfb")
  • sfAId_, sfBId_ (camelCase "sfA"/"sfB")

This propagates to the reflection bindings (sfaRegion vs sfAId). For API consistency, consider aligning to one pattern (e.g., sfARegion_/sfBRegion_ or sfaId_/sfbId_).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/gemm_py.h` around lines 39 - 41, The field names for block-scaled GEMM
are inconsistent: you have sfaRegion_ / sfbRegion_ but sfAId_ / sfBId_ (and the
reflection bindings use mixed forms like sfaRegion vs sfAId); pick one
consistent casing (e.g., sfARegion_ / sfBRegion_ and sfAId_ / sfBId_ or
sfaRegion_ / sfbRegion_ and sfaId_ / sfbId_) and rename the symbols accordingly
throughout the class and its reflection bindings so declarations (sfaRegion_,
sfbRegion_, sfAId_, sfBId_), any getters/setters, and reflection names
(sfaRegion, sfAId, etc.) all match the chosen pattern; ensure you update all
usages and serialization/reflection code that references these identifiers
(search for sfaRegion_, sfbRegion_, sfAId_, sfBId_, sfaRegion, sfbRegion, sfAId,
sfBId).
tilelang/language/gemm_op.py (1)

320-321: Minor: Remove unused mbar_arg variable.

Line 321 assigns mbar_arg = mbar but line 351 passes mbar directly, making mbar_arg unused.

Cleanup
     assert mbar is not None, "mbar is required for blockscaled_gemm"
-    mbar_arg = mbar 
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/language/gemm_op.py` around lines 320 - 321, The assignment
"mbar_arg = mbar" is redundant because the code later passes "mbar" directly;
remove the unused variable by deleting the "mbar_arg = mbar" line (keeping the
assertion assert mbar is not None, "mbar is required for blockscaled_gemm") and
verify no other references to mbar_arg exist in the surrounding functions or the
blockscaled_gemm call sites.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@src/op/builtin.cc`:
- Around line 212-226: The three TIR builtin registrations
(ptx_tcgen05_mma_blockscaled_ss, ptx_tcgen05_cp, ptx_tcgen05_sf_warp_transpose)
have minor formatting inconsistencies: remove the extra space before Integer in
the .set_attr call for ptx_tcgen05_sf_warp_transpose so it matches the other
registrations and delete the trailing whitespace at the end of that line; ensure
the .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); call formatting is consistent across all
three builtins.

In `@src/op/builtin.h`:
- Around line 352-365: Remove the extra leading space before the declaration of
the ptx_tcgen05_sf_warp_transpose intrinsic so it aligns with the other TVM_DLL
declarations (make the line start with "TVM_DLL const Op
&ptx_tcgen05_sf_warp_transpose();"); this fixes the formatting inconsistency in
the block containing ptx_tcgen05_mma_blockscaled_ss and ptx_tcgen05_cp.

In `@src/op/gemm_py.cc`:
- Around line 90-102: The parser in gemm_py.cc allows specifying sfaRegion_,
sfbRegion_, sfAId_, sfBId_ independently which can produce a partial
block-scaled specification that gemm_base.py's is_blockscaled (which expects
both regions) will treat as not blockscaled; update the parsing logic so that
when any of the four related args (sfaRegion_, sfbRegion_, sfAId_, sfBId_) is
present you require all four — e.g., after reading args[19..22] validate
presence of the other fields and raise an error or throw a runtime/parse
exception if any are missing; locate the assignments to sfaRegion_, sfbRegion_,
sfAId_, sfBId_ and the use of NormalizeToBufferRegion to implement this
all-or-none validation, referencing is_blockscaled in gemm_base.py as the
behavioral expectation.

In `@src/op/gemm_py.h`:
- Around line 39-41: The field names for block-scaled GEMM are inconsistent: you
have sfaRegion_ / sfbRegion_ but sfAId_ / sfBId_ (and the reflection bindings
use mixed forms like sfaRegion vs sfAId); pick one consistent casing (e.g.,
sfARegion_ / sfBRegion_ and sfAId_ / sfBId_ or sfaRegion_ / sfbRegion_ and
sfaId_ / sfbId_) and rename the symbols accordingly throughout the class and its
reflection bindings so declarations (sfaRegion_, sfbRegion_, sfAId_, sfBId_),
any getters/setters, and reflection names (sfaRegion, sfAId, etc.) all match the
chosen pattern; ensure you update all usages and serialization/reflection code
that references these identifiers (search for sfaRegion_, sfbRegion_, sfAId_,
sfBId_, sfaRegion, sfbRegion, sfAId, sfBId).

In `@tilelang/language/gemm_op.py`:
- Around line 320-321: The assignment "mbar_arg = mbar" is redundant because the
code later passes "mbar" directly; remove the unused variable by deleting the
"mbar_arg = mbar" line (keeping the assertion assert mbar is not None, "mbar is
required for blockscaled_gemm") and verify no other references to mbar_arg exist
in the surrounding functions or the blockscaled_gemm call sites.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f8addcf4-750a-4654-8fd8-ab0a3cd86b35

📥 Commits

Reviewing files that changed from the base of the PR and between 9966ba5 and 1b5032a.

📒 Files selected for processing (14)
  • examples/gemm_sm100/gemm_mxfp8_blockscaled.py
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/gemm_py.cc
  • src/op/gemm_py.h
  • src/target/codegen_cuda.cc
  • tilelang/language/ast/ir.py
  • tilelang/language/builtin.py
  • tilelang/language/gemm_op.py
  • tilelang/language/tir/ir.py
  • tilelang/language/tir/op.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_base.py
  • tilelang/tileop/gemm/gemm_tcgen05.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/language/tir/ir.py

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py (2)

37-37: ⚠️ Potential issue | 🟠 Major

Derive SF IDs from granularity, not reload-period index.

At Line 128 and Line 129, k % sf_load_period can exceed 3 when one SF value spans multiple K tiles (e.g., block_K < sf_granularity_k). For blockscaled_gemm, SF IDs need to remain in 0..3.

Safer parameterized mapping
-    sf_load_period = sf_granularity_k * 4 // block_K
+    assert sf_granularity_k % block_K == 0, (
+        "This kernel currently requires block_K to evenly divide sf_granularity_k"
+    )
+    sf_tiles_per_value = sf_granularity_k // block_K
+    sf_load_period = sf_tiles_per_value * 4
@@
-                    sf_a_id=k % sf_load_period,
-                    sf_b_id=k % sf_load_period,
+                    sf_a_id=(k // sf_tiles_per_value) % 4,
+                    sf_b_id=(k // sf_tiles_per_value) % 4,

Use this read-only check to confirm the API constraint and call-site mismatch:

#!/bin/bash
set -euo pipefail

# Verify blockscaled_gemm API docs/args for SF ID constraints in TileLang
rg -n -C3 'sf_a_id|sf_b_id|Scale factor ID|0-3' tilelang/language/gemm_op.py

# Verify current example call-site derivation
rg -n -C3 'sf_load_period|sf_a_id=|sf_b_id=' examples/gemm_sm100/gemm_mxfp8_blockscaled.py

Also applies to: 119-129

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` at line 37, The current
derivation uses sf_load_period and k % sf_load_period which can yield values >3;
instead compute SF IDs from granularity: derive the SF tile index as (k //
block_K) // sf_granularity_k and then take modulo 4 to get an ID in 0..3, e.g.
sf_tile_index = ((k // block_K) // sf_granularity_k) and sf_id = sf_tile_index %
4, and use that for sf_a_id and sf_b_id; update the code that computes
sf_load_period and the places assigning sf_a_id/sf_b_id to use this
granularity-based mapping (referencing sf_load_period, sf_granularity_k,
block_K, sf_a_id, sf_b_id).

35-37: ⚠️ Potential issue | 🟠 Major

Guard the current “happy-path only” assumptions explicitly.

Line 37 can yield sf_load_period == 0, and Lines 61-62 floor-truncate UTCCP chunk counts. That can trigger divide-by-zero at Lines 93/111 or silently skip SF data when block sizes change.

Suggested guardrails
     k_iters = T.ceildiv(K, block_K)
     # 4 packed E8M0 per uint32 → load every 4 stages
-    sf_load_period = sf_granularity_k * 4 // block_K
+    assert block_K > 0 and sf_granularity_k > 0
+    assert (sf_granularity_k * 4) % block_K == 0, (
+        "block_K must evenly divide 4 * sf_granularity_k"
+    )
+    sf_load_period = sf_granularity_k * 4 // block_K
@@
-        sfa_num_chunks = block_M // 128  # number of 128-element UTCCP chunks
-        sfb_num_chunks = block_N // 128
+        assert block_M % 128 == 0 and block_N % 128 == 0, (
+            "UTCCP path currently requires block_M and block_N to be multiples of 128"
+        )
+        sfa_num_chunks = block_M // 128
+        sfb_num_chunks = block_N // 128

Also applies to: 61-64, 93-94, 111-111

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 35 - 37, The
computation for sf_load_period and subsequent UTCCP chunk counts assumes
non-zero divisors and can produce sf_load_period == 0 (or zero chunk counts)
leading to divide-by-zero or dropped SF data; update the calculations that
produce sf_load_period (currently sf_load_period = sf_granularity_k * 4 //
block_K) and any derived UTCCP chunk counts to defensively enforce a minimum of
1 (e.g., wrap with max(1, ...)), and add explicit guards or assertions before
using these values as divisors so code using k_iters, sf_load_period, and the
UTCCP chunk variables cannot divide by zero or silently skip SF processing.
Ensure you modify all uses referenced (sf_load_period, sf_granularity_k, block_K
and the chunk count calculations) so they never become zero.
🧹 Nitpick comments (2)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py (1)

235-235: Avoid printing full 8192×8192 tensors in the example path.

Line 235 can dominate runtime/log volume and make benchmark output hard to use. Prefer shape/stats or gated debug printing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` at line 235, The example
currently prints full tensors c and ref_c (print(f"{c=}, {ref_c=}")), which will
log huge 8192×8192 arrays; replace that line to avoid dumping full tensors by
printing compact diagnostics instead (e.g., shapes and summary statistics like
c.shape, c.dtype, c.mean().item(), c.std().item(), and a max absolute difference
between c and ref_c) or wrap the full-tensor print behind a verbose/debug flag
so normal runs only emit lightweight summaries; update the print call in
gemm_mxfp8_blockscaled.py at the location of variables c and ref_c accordingly.
src/op/copy.cc (1)

1319-1326: Consider applying buffer remapping for dynamic shared memory.

The code uses src->data directly without checking T.buffer_remap. For dynamic shared memory (shared.dyn), the buffer may need remapping similar to other copy paths (e.g., LowerBulkCopy at line 1684 and LowerLDSMCopy at line 1102).

♻️ Proposed fix
+    // Apply buffer remapping for dynamic shared memory
+    Buffer smem_buf = src;
+    if (T.buffer_remap.count(src)) {
+      smem_buf = T.buffer_remap.at(src);
+    }
+
     // SMEM base access_ptr
-    PrimExpr ptype = tir::TypeAnnotation(src->dtype);
+    PrimExpr ptype = tir::TypeAnnotation(smem_buf->dtype);
     auto make_smem_ptr = [&](PrimExpr elem_offset) {
       return Call(DataType::Handle(), builtin::tvm_access_ptr(),
-                  {ptype, src->data, elem_offset,
+                  {ptype, smem_buf->data, elem_offset,
                    IntImm(DataType::Int(32), ELEMENTS_PER_CP),
                    IntImm(DataType::Int(32), 1 /*read*/)});
     };
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 1319 - 1326, The shared-memory pointer builder
make_smem_ptr currently uses src->data directly, which fails to apply buffer
remapping for dynamic shared memory; modify the lambda (and the surrounding
logic where it captures src) to check for T.buffer_remap (the same remapping
used in LowerBulkCopy/LowerLDSMCopy) and use the remapped buffer handle instead
of src->data when buffer_remap is present (i.e., for shared.dyn), so the Call to
tvm_access_ptr receives the remapped buffer pointer rather than the original
src->data.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/copy.cc`:
- Around line 1341-1344: Guard against an empty stmts vector before accessing
stmts[0]: check if stmts.empty() (derived from total_elements/num_calls) and
either return an appropriate no-op/empty Stmt early or set body to an explicit
empty statement instead of dereferencing stmts[0]; update the block around
variables stmts, body and the SeqStmt construction so the loop only runs when
stmts is non-empty to avoid undefined behavior.

---

Duplicate comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Line 37: The current derivation uses sf_load_period and k % sf_load_period
which can yield values >3; instead compute SF IDs from granularity: derive the
SF tile index as (k // block_K) // sf_granularity_k and then take modulo 4 to
get an ID in 0..3, e.g. sf_tile_index = ((k // block_K) // sf_granularity_k) and
sf_id = sf_tile_index % 4, and use that for sf_a_id and sf_b_id; update the code
that computes sf_load_period and the places assigning sf_a_id/sf_b_id to use
this granularity-based mapping (referencing sf_load_period, sf_granularity_k,
block_K, sf_a_id, sf_b_id).
- Around line 35-37: The computation for sf_load_period and subsequent UTCCP
chunk counts assumes non-zero divisors and can produce sf_load_period == 0 (or
zero chunk counts) leading to divide-by-zero or dropped SF data; update the
calculations that produce sf_load_period (currently sf_load_period =
sf_granularity_k * 4 // block_K) and any derived UTCCP chunk counts to
defensively enforce a minimum of 1 (e.g., wrap with max(1, ...)), and add
explicit guards or assertions before using these values as divisors so code
using k_iters, sf_load_period, and the UTCCP chunk variables cannot divide by
zero or silently skip SF processing. Ensure you modify all uses referenced
(sf_load_period, sf_granularity_k, block_K and the chunk count calculations) so
they never become zero.

---

Nitpick comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Line 235: The example currently prints full tensors c and ref_c (print(f"{c=},
{ref_c=}")), which will log huge 8192×8192 arrays; replace that line to avoid
dumping full tensors by printing compact diagnostics instead (e.g., shapes and
summary statistics like c.shape, c.dtype, c.mean().item(), c.std().item(), and a
max absolute difference between c and ref_c) or wrap the full-tensor print
behind a verbose/debug flag so normal runs only emit lightweight summaries;
update the print call in gemm_mxfp8_blockscaled.py at the location of variables
c and ref_c accordingly.

In `@src/op/copy.cc`:
- Around line 1319-1326: The shared-memory pointer builder make_smem_ptr
currently uses src->data directly, which fails to apply buffer remapping for
dynamic shared memory; modify the lambda (and the surrounding logic where it
captures src) to check for T.buffer_remap (the same remapping used in
LowerBulkCopy/LowerLDSMCopy) and use the remapped buffer handle instead of
src->data when buffer_remap is present (i.e., for shared.dyn), so the Call to
tvm_access_ptr receives the remapped buffer pointer rather than the original
src->data.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 395739bb-5626-4e6f-b16b-88a2c241c98b

📥 Commits

Reviewing files that changed from the base of the PR and between 1b5032a and 0a5394e.

📒 Files selected for processing (3)
  • examples/gemm_sm100/gemm_mxfp8_blockscaled.py
  • src/op/copy.cc
  • src/op/copy.h

Comment thread src/op/copy.cc
Comment on lines +1341 to +1344
Stmt body = stmts[0];
for (size_t i = 1; i < stmts.size(); ++i) {
body = SeqStmt({body, stmts[i]});
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Guard against empty stmts vector to avoid undefined behavior.

If total_elements is 0 (e.g., a zero-extent region), num_calls will be 0 and stmts will be empty. Accessing stmts[0] on line 1341 would then be undefined behavior.

🛡️ Proposed fix
     // Build a sequence of ptx_tcgen05_cp calls
     std::vector<Stmt> stmts;
     for (int i = 0; i < num_calls; ++i) {
       PrimExpr smem_ptr = make_smem_ptr(flat_offset + i * ELEMENTS_PER_CP);
       PrimExpr col_offset = IntImm(DataType::Int(32), i * COLS_PER_CP);
       stmts.push_back(Evaluate(
           Call(DataType::Void(), ptx_tcgen05_cp(),
                {smem_ptr, tmem_ptr, col_offset})));
     }

+    if (stmts.empty()) {
+      return Evaluate(0);  // No-op for zero-element copy
+    }
+
     Stmt body = stmts[0];
     for (size_t i = 1; i < stmts.size(); ++i) {
       body = SeqStmt({body, stmts[i]});
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 1341 - 1344, Guard against an empty stmts vector
before accessing stmts[0]: check if stmts.empty() (derived from
total_elements/num_calls) and either return an appropriate no-op/empty Stmt
early or set body to an explicit empty statement instead of dereferencing
stmts[0]; update the block around variables stmts, body and the SeqStmt
construction so the loop only runs when stmts is non-empty to avoid undefined
behavior.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py`:
- Around line 226-230: Cosine similarity is insufficient; add an elementwise
absolute and relative error check between c and ref_c after computing sim:
convert c and ref_c to a higher-precision dtype (e.g., float32), compute
absolute errors (abs_err = |c - ref_c|) and relative errors (rel_err = abs_err /
(|ref_c| + eps)), get max_abs_err and max_rel_err, and assert they are below
chosen thresholds (e.g., max_abs_err < 1e-2 and max_rel_err < 1e-2); keep the
existing cosine_similarity check and include clear failure messages referencing
c, ref_c, sim, max_abs_err, and max_rel_err so blockscaled_gemm_ref / c
mismatches are detected even when angles match.
- Around line 54-57: The code computes sfa_num_chunks and sfb_num_chunks and
allocates SFA_tmem/SFB_tmem assuming block_M and block_N are exact multiples of
128, but the transpose only processes block_M // 128 and block_N // 128 chunks;
add explicit runtime assertions that block_M % 128 == 0 and block_N % 128 == 0
(with clear messages referencing block_M and block_N) before computing
sfa_num_chunks/sfb_num_chunks and allocating SFA_tmem/SFB_tmem to prevent silent
truncation/corruption.
- Around line 37-38: The calculation for sf_load_period using sf_granularity_k *
4 // block_K can produce scale-factor IDs (sf_a_id, sf_b_id) outside the valid
0-3 range for different block_K values; update the logic around sf_load_period,
sf_granularity_k, and the code that computes sf_a_id/sf_b_id so IDs are derived
from the scale-block structure (e.g., index within a scale-block of size
sf_granularity_k or a bounded modulo) rather than raw K-tile indices, and add a
validation/assertion that sf_a_id and sf_b_id are within 0..3 (or clamp/raise a
configuration error) before they are used in blockscaled_gemm so the contract is
always satisfied.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 018d55e1-e3e6-4926-8b00-08f635ad4831

📥 Commits

Reviewing files that changed from the base of the PR and between 0a5394e and 4bc7175.

📒 Files selected for processing (1)
  • examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py

Comment on lines +37 to +38
sf_load_period = sf_granularity_k * 4 // block_K

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's read the file to understand the actual configuration
cat -n examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py | head -150

Repository: tile-ai/tilelang

Length of output: 6760


🏁 Script executed:

# Also find and examine the blockscaled_gemm API definition
fd -t f "gemm_op.py" --exec grep -l "blockscaled_gemm" {} \;

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

# And check the gemm_tcgen05 file mentioned
fd -t f "gemm_tcgen05.py"

Repository: tile-ai/tilelang

Length of output: 137


🏁 Script executed:

# Check for the CUDA template file
fd -t f "tcgen_05.h"

Repository: tile-ai/tilelang

Length of output: 92


🏁 Script executed:

# Read the blockscaled_gemm API definition
grep -A 50 "def blockscaled_gemm" tilelang/language/gemm_op.py | head -70

Repository: tile-ai/tilelang

Length of output: 1862


🏁 Script executed:

# Read the lowering implementation
cat -n tilelang/tileop/gemm/gemm_tcgen05.py | sed -n '189,250p'

Repository: tile-ai/tilelang

Length of output: 1976


🏁 Script executed:

# Read the CUDA template
cat -n src/tl_templates/cuda/tcgen_05.h | sed -n '60,100p'

Repository: tile-ai/tilelang

Length of output: 2123


🏁 Script executed:

# Read the main() section of the example file to see actual parameter values
cat -n examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py | tail -100

Repository: tile-ai/tilelang

Length of output: 4694


🏁 Script executed:

# Let's verify the mathematical relationship more explicitly
# with the actual parameters from main()
python3 << 'EOF'
# Actual parameters from main()
M, N, K = 8192, 8192, 8192
block_M, block_N, block_K = 128, 256, 128
sf_granularity_k = 128

# Current formula at line 37
sf_load_period_current = sf_granularity_k * 4 // block_K
print(f"Current formula: {sf_granularity_k} * 4 // {block_K} = {sf_load_period_current}")

# Proposed formula
tiles_per_scale = sf_granularity_k // block_K
sf_load_period_proposed = tiles_per_scale * 4
print(f"Proposed formula: ({sf_granularity_k} // {block_K}) * 4 = {sf_load_period_proposed}")

# Number of scale factor blocks
sf_k_blocks = (K + sf_granularity_k - 1) // sf_granularity_k
sf_k_packed = (sf_k_blocks + 3) // 4
print(f"\nScale factor structure:")
print(f"  sf_k_blocks (unpacked): {sf_k_blocks}")
print(f"  sf_k_packed (uint32): {sf_k_packed}")
print(f"  Each uint32 covers: {sf_granularity_k * 4} K elements = {(sf_granularity_k * 4) // block_K} K-tiles")

# K iteration logic
k_iters = (K + block_K - 1) // block_K
print(f"\nK-iteration logic:")
print(f"  Total K-tile iterations: {k_iters}")
print(f"  Scale factors loaded every {sf_load_period_current} iterations")
print(f"  sf_a_id cycles: {list(range(sf_load_period_current))}")

# Scale loading points
load_points = [k for k in range(k_iters) if k % sf_load_period_current == 0]
print(f"  Scale factors loaded at k indices: {load_points[:10]}... (showing first 10)")
print(f"  Total scale loads: {len(load_points)}")

# Test with different block_K to show the problem
print(f"\n--- Testing with different block_K values ---")
for test_block_K in [64, 128, 256]:
    print(f"\nblock_K = {test_block_K}:")
    sf_load_period_test = sf_granularity_k * 4 // test_block_K
    print(f"  Current formula produces sf_load_period = {sf_load_period_test}")
    print(f"  sf_a_id would cycle through range(0, {sf_load_period_test})")
    if sf_load_period_test > 4:
        print(f"  ⚠️  INVALID: sf_a_id exceeds valid range 0-3!")
    
    if sf_granularity_k % test_block_K == 0:
        tiles_per_scale_test = sf_granularity_k // test_block_K
        sf_load_period_safe = tiles_per_scale_test * 4
        print(f"  Proposed formula produces sf_load_period = {sf_load_period_safe}")
        print(f"  sf_a_id = (k // {tiles_per_scale_test}) % 4 ✓ Always in range 0-3")
    else:
        print(f"  ⚠️  sf_granularity_k ({sf_granularity_k}) not divisible by block_K ({test_block_K})")
EOF

Repository: tile-ai/tilelang

Length of output: 1162


🏁 Script executed:

# Let's also check the blockscaled_gemm lowering to see if it validates sf_a_id/sf_b_id
grep -B 5 -A 15 "sf_a_id\|sf_b_id" tilelang/tileop/gemm/gemm_tcgen05.py | head -50

Repository: tile-ai/tilelang

Length of output: 1354


🏁 Script executed:

# Check if there are any assertions on sf_a_id in the MMA emitter
fd -t f "tcgen05" -type f | xargs grep -l "sf_a_id\|sf_b_id" 2>/dev/null

Repository: tile-ai/tilelang

Length of output: 232


🏁 Script executed:

# Let's check the actual MMA emitter implementation
grep -A 20 "def tcgen05mma_blockscaled" tilelang/tileop/mma/tcgen05_mma.py 2>/dev/null | head -40

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if tcgen05mma_blockscaled validates sf_a_id
rg "tcgen05mma_blockscaled" --type py -A 30 | head -60

Repository: tile-ai/tilelang

Length of output: 3945


🏁 Script executed:

# Check the CUDA template for any ID validation
rg "sf_a_id|sf_b_id" src/tl_templates/cuda/ -A 2 -B 2

Repository: tile-ai/tilelang

Length of output: 42


Ensure scale factor IDs remain in valid range (0-3) for all block_K configurations.

The blockscaled_gemm API strictly requires sf_a_id and sf_b_id to be in range 0-3. Line 37's formula sf_load_period = sf_granularity_k * 4 // block_K happens to be safe for the current example (block_K=128 equals sf_granularity_k=128), but will violate this contract for other block_K values. For example, if block_K=64, the formula produces sf_load_period=8, causing lines 114-115 to emit IDs up to 7. Add validation to constrain the configuration, then derive scale IDs from the scale block structure rather than raw K-tile indices.

Suggested fix
-    sf_load_period = sf_granularity_k * 4 // block_K
+    if sf_granularity_k % block_K != 0:
+        raise ValueError("sf_granularity_k must be a multiple of block_K")
+    tiles_per_scale = sf_granularity_k // block_K
+    sf_load_period = tiles_per_scale * 4
...
-                    sf_a_id=k % sf_load_period,
-                    sf_b_id=k % sf_load_period,
+                    sf_a_id=(k // tiles_per_scale) % 4,
+                    sf_b_id=(k // tiles_per_scale) % 4,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py` around lines 37 -
38, The calculation for sf_load_period using sf_granularity_k * 4 // block_K can
produce scale-factor IDs (sf_a_id, sf_b_id) outside the valid 0-3 range for
different block_K values; update the logic around sf_load_period,
sf_granularity_k, and the code that computes sf_a_id/sf_b_id so IDs are derived
from the scale-block structure (e.g., index within a scale-block of size
sf_granularity_k or a bounded modulo) rather than raw K-tile indices, and add a
validation/assertion that sf_a_id and sf_b_id are within 0..3 (or clamp/raise a
configuration error) before they are used in blockscaled_gemm so the contract is
always satisfied.

Comment on lines +54 to +57
sfa_num_chunks = block_M // 128
sfb_num_chunks = block_N // 128
SFA_tmem = T.alloc_tmem([32, sfa_num_chunks * 4], "uint32")
SFB_tmem = T.alloc_tmem([32, sfb_num_chunks * 4], "uint32")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -name "gemm_mxfp8_blockscaled_1_128_128.py" -type f

Repository: tile-ai/tilelang

Length of output: 117


🏁 Script executed:

wc -l ./examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py

Repository: tile-ai/tilelang

Length of output: 121


🏁 Script executed:

cat -n ./examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py

Repository: tile-ai/tilelang

Length of output: 11309


Enforce multiples of 128 for block_M and block_N.

Lines 54–57 allocate TMEM based on floor-divided chunk counts, but the transpose loop (lines 99–102) only processes block_M // 128 and block_N // 128 chunks of 128 elements each. When block_M or block_N are not multiples of 128, the remaining elements are never transposed but are still copied wholesale into undersized TMEM buffers at lines 103–104, leading to data corruption. Add explicit assertions that both must be multiples of 128:

+        if block_M % 128 != 0 or block_N % 128 != 0:
+            raise ValueError("block_M and block_N must be multiples of 128")
         sfa_num_chunks = block_M // 128
         sfb_num_chunks = block_N // 128
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
sfa_num_chunks = block_M // 128
sfb_num_chunks = block_N // 128
SFA_tmem = T.alloc_tmem([32, sfa_num_chunks * 4], "uint32")
SFB_tmem = T.alloc_tmem([32, sfb_num_chunks * 4], "uint32")
if block_M % 128 != 0 or block_N % 128 != 0:
raise ValueError("block_M and block_N must be multiples of 128")
sfa_num_chunks = block_M // 128
sfb_num_chunks = block_N // 128
SFA_tmem = T.alloc_tmem([32, sfa_num_chunks * 4], "uint32")
SFB_tmem = T.alloc_tmem([32, sfb_num_chunks * 4], "uint32")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py` around lines 54 -
57, The code computes sfa_num_chunks and sfb_num_chunks and allocates
SFA_tmem/SFB_tmem assuming block_M and block_N are exact multiples of 128, but
the transpose only processes block_M // 128 and block_N // 128 chunks; add
explicit runtime assertions that block_M % 128 == 0 and block_N % 128 == 0 (with
clear messages referencing block_M and block_N) before computing
sfa_num_chunks/sfb_num_chunks and allocating SFA_tmem/SFB_tmem to prevent silent
truncation/corruption.

Comment on lines +226 to +230
ref_c = blockscaled_gemm_ref(a, b, sfa_unpacked[:, :sf_k_blocks], sfb_unpacked_coarse, sf_granularity_k, sf_granularity_n).to(torch.bfloat16)
sim = cosine_similarity(c, ref_c)
print(f"Output shape: {c.shape}, dtype: {c.dtype}")
print(f"Cosine similarity: {sim.item():.6f}")
assert sim > 0.99, f"Cosine similarity too low: {sim.item():.6f}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Cosine similarity alone is too weak for scale-factor validation.

A uniform exponent or magnitude bug on the block-scaling path keeps the angle almost unchanged, so Lines 227-230 can pass while the output is still numerically wrong. Please add an elementwise absolute/relative-error check alongside the cosine gate.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py` around lines 226 -
230, Cosine similarity is insufficient; add an elementwise absolute and relative
error check between c and ref_c after computing sim: convert c and ref_c to a
higher-precision dtype (e.g., float32), compute absolute errors (abs_err = |c -
ref_c|) and relative errors (rel_err = abs_err / (|ref_c| + eps)), get
max_abs_err and max_rel_err, and assert they are below chosen thresholds (e.g.,
max_abs_err < 1e-2 and max_rel_err < 1e-2); keep the existing cosine_similarity
check and include clear failure messages referencing c, ref_c, sim, max_abs_err,
and max_rel_err so blockscaled_gemm_ref / c mismatches are detected even when
angles match.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant