[Feature] Block-scaled GEMM support for MXFP8 on Blackwell #1945
Rachmanino wants to merge 7 commits into tile-ai:main
Conversation
- Introduced a new Python file for MXFP8 block-scaled GEMM implementation.
- Added intrinsic definitions for block-scaled MMA, UTCCP, and scale factor operations in the C++ backend.
- Implemented corresponding TileLang functions for block-scaled GEMM and scale factor handling.
- Enhanced CUDA code generation to support new block-scaled instructions.
- Updated documentation to reflect new functionalities and usage examples.

This commit enhances the performance of GEMM operations on the SM100 architecture by leveraging block scaling and optimized memory access patterns.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds end-to-end block-scaled MXFP8 GEMM support for TCGEN5/SM100: new TileLang frontend APIs and intrinsics, TileLang lowering and macro emission for block-scaled MMA, TCGEN5 builtins and device helpers (UTCCP copy, SF warp-transpose, sync hooks), CUDA codegen/templates, TMEM copy path, and examples with validation and benchmark.
Sequence Diagram(s)

sequenceDiagram
participant User as User Code
participant PyAPI as blockscaled_gemm (Python API)
participant Emitter as TensorCoreIntrinEmitter
participant MacroGen as tcgen05_macro_generator
participant Codegen as CUDA Codegen
participant Device as GPU Device
User->>PyAPI: call blockscaled_gemm(A,B,C,SFA,SFB,...)
PyAPI->>Emitter: configure layouts, TMEM for C, and SF regions
Emitter->>MacroGen: request tcgen05mma_blockscaled emission (with SF IDs)
MacroGen->>MacroGen: build blockscaled instr desc, resolve TMEM SF pointers
MacroGen->>Codegen: emit ptx_tcgen05_mma_blockscaled_ss / ptx_tcgen05_cp / ptx_tcgen05_sf_warp_transpose / sync calls
Codegen->>Device: generate tcgen05.mma.cta_group / tcgen05.cp / sf_transpose / sync intrinsics
Device->>Device: copy SFs to TMEM, warp‑transpose SFs, perform block‑scaled MMA, write TMEM→DRAM
Device-->>User: results written to C
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 11
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Around line 35-37: The current calculations use floor integer division which
can produce 0 or drop tail chunks (e.g., k_iters, sf_load_period,
sfa_num_chunks, sfb_num_chunks) causing divide-by-zero or missing scale-factor
data; change these to use ceiling division (or explicitly clamp sf_load_period =
max(1, ceil_div(...))) and compute sfa_num_chunks/sfb_num_chunks with ceil
division so tail chunks are included, and add a short comment documenting the
required assumptions for sf_granularity_k, block_K, block_M, and block_N to make
the behavior explicit (update the code around k_iters, sf_load_period, and the
sfa_num_chunks/sfb_num_chunks calculations; same fix for the analogous block at
lines 61-64).
In `@src/op/tcgen5_meta.h`:
- Around line 218-265: Summary: a_sf_id and b_sf_id are packed into 2-bit fields
and currently can be silently truncated; validate them before packing in
tl.get_tcgen5_blockscaled_instr_desc. Add explicit checks (e.g., ICHECK(a_sf_id
>= 0 && a_sf_id < 4) and ICHECK(b_sf_id >= 0 && b_sf_id < 4)) near the top of
the function that builds the descriptor (the same scope that defines
encode_mxfp_dtype, set_bits, and computes desc) so invalid values fail fast;
then continue to cast to uint32_t when calling
set_bits(static_cast<uint32_t>(a_sf_id), 29, 2) and
set_bits(static_cast<uint32_t>(b_sf_id), 4, 2).
In `@src/target/codegen_cuda.cc`:
- Around line 2541-2543: The replacer currently emits
"tcgen05mma_blockscaled_ws_ss" when enable_ws is true, but that WS helper
doesn't exist yet; update the registration in replacer.register_rule for
"(tcgen05_name)" to avoid emitting the WS variant until implemented—either
always emit "tcgen05mma_blockscaled_ss" regardless of enable_ws, or add an
explicit guard (e.g., if (enable_ws) { /* TODO: skip or fall back */ } else {
replacer.register_rule(... "tcgen05mma_blockscaled_ss"); }) so the code uses
"tcgen05mma_blockscaled_ss" and never references "tcgen05mma_blockscaled_ws_ss"
until the WS helper/header is added.
In `@src/tl_templates/cuda/tcgen_05.h`:
- Around line 72-82: The in-place warp transpose in tcgen05_sf_warp_transpose
performs a pre-store __syncwarp() but lacks a post-store barrier, so
shared-memory writes may be visible partially to a subsequent UTCCP consumer;
add a second __syncwarp() immediately after the store loop (after the lines that
write smem_ptr[lane * 4 + ...]) to ensure all threads have completed their
stores before returning, or alternatively document/require the caller to perform
this post-store synchronization.
In `@tilelang/intrinsics/tcgen05_macro_generator.py`:
- Around line 720-722: The packing of SF IDs into runtime_instr_desc (using
base_instr_desc | (_sf_a << 29) | (_sf_b << 4)) must mask and validate sf_a_id
and sf_b_id: when creating _sf_a/_sf_b (and before shifting) apply a mask like
(& 0x3) so only the bottom 2 bits are OR-ed, and if sf_a_id or sf_b_id are
constant ints reject/raise if they are outside the allowed range 0..3; update
the logic around the tvm.tir.const creation and the variables _sf_a, _sf_b to
both validate constants and mask non-constants prior to shifting to avoid
spilling into neighboring descriptor fields.
- Around line 599-608: The code currently ignores the returned WS and 2CTA flags
from get_tcgen5_mma_meta (meta -> _enable_ws, enable_2cta) and unconditionally
forces enable_ws=0, which can mis-lower shapes that require WS or 2CTA; update
the logic in tcgen05_macro_generator.py where meta is unpacked
(get_tcgen5_mma_meta and atom_m/atom_n/atom_k/_enable_ws/enable_2cta) to
validate that _enable_ws==0 and enable_2cta==0 before proceeding and otherwise
raise a ValueError (with a clear message including M/N/K/a_dtype/accum_dtype) so
the method fails fast for unsupported WS or 2CTA configurations instead of
silently forcing cta_group::1.
- Around line 681-694: The code currently strips BufferRegion operands to their
underlying buffer data and then sets sfa_offset/sfb_offset to 0, causing
subregions to lose their base offsets; update the SFA_tmem/SFB_tmem handling so
that when the operand is a BufferRegion you capture both the buffer data and its
base offset (e.g. set sfa_data = SFA_tmem.buffer.data and sfa_offset =
SFA_tmem.offset) and when it is a Buffer set sfa_data = SFA_tmem.data and
sfa_offset = 0 (do the analogous change for SFB_tmem -> sfb_data and
sfb_offset); ensure later code uses these sfa_offset/sfb_offset values instead
of hard-coding 0.
In `@tilelang/language/builtin.py`:
- Around line 1040-1048: The code in tcgen05_cp drops the TMEM slice offset by
replacing BufferLoad/BufferRegion destinations with only buffer.data; preserve
the destination's region/indices when computing tmem_ptr so the tmem_col_offset
is applied correctly. Update the branches handling BufferLoad and BufferRegion
to compute and incorporate the region index/offset (from the BufferLoad indices
or BufferRegion.region/offset fields) into tmem_ptr (or augment tmem_col_offset)
before calling tir.call_intrin("void", tir.op.Op.get("tl.ptx_tcgen05_utccp"),
smem_ptr, tmem_ptr, tmem_col_offset) so sliced destinations write to the correct
TMEM column. Ensure you reference BufferLoad, BufferRegion, tmem_dst, tmem_ptr
and tmem_col_offset when making the change.
In `@tilelang/language/gemm_op.py`:
- Around line 295-302: Before constructing the TensorCoreIntrinEmitter, validate
that A and B FP8 formats match: read a_dtype = str(A_region.buffer.dtype) and
b_dtype = str(B_region.buffer.dtype) and if they are different FP8 encodings
(e.g., E4M3 vs E5M2) raise/return a clear error instead of proceeding; ensure
you either pass the real b_dtype into TensorCoreIntrinEmitter (b_dtype=b_dtype)
or abort when mixed FP8 formats are detected. This change touches the emitter
construction site (TensorCoreIntrinEmitter) and the code paths that read
A_region.buffer.dtype and B_region.buffer.dtype so the block-scaled descriptor
is not silently built with a single dtype for both operands.
- Around line 327-331: The return annotation "Layout" in
make_blockscaled_gemm_layout is not imported and triggers Pyflakes F821; add an
import for Layout (e.g., from tilelang.layout import Layout) at module level or,
if you want to avoid runtime imports, wrap it in a TYPE_CHECKING block (from
typing import TYPE_CHECKING; if TYPE_CHECKING: from tilelang.layout import
Layout) so the symbol exists for linting while preserving runtime behavior.
In `@tilelang/language/tir/op.py`:
- Around line 1309-1323: The ptx_tcgen05_utccp intrinsic currently names its
first parameter smem_desc but lowering expects a raw shared-memory pointer;
rename the parameter and its docstring from smem_desc to smem_ptr in the
ptx_tcgen05_utccp function (and update any callers/docs) so the frontend
contract matches the lowering, or alternatively add a distinct intrinsic that
explicitly accepts an smem descriptor if you need both forms; ensure the
call_intrin invocation still passes the pointer value and update the docstring
to describe smem_ptr as a raw shared-memory pointer.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c73147b3-3b03-42ed-8ea2-9e21ff0ce44b
📒 Files selected for processing (15)
- examples/gemm_sm100/gemm_mxfp8_blockscaled.py
- src/op/builtin.cc
- src/op/builtin.h
- src/op/gemm_py.cc
- src/op/tcgen5_meta.h
- src/target/codegen_cuda.cc
- src/tl_templates/cuda/instruction/tcgen05mma.h
- src/tl_templates/cuda/tcgen_05.h
- tilelang/intrinsics/tcgen05_macro_generator.py
- tilelang/language/__init__.py
- tilelang/language/ast/ir.py
- tilelang/language/builtin.py
- tilelang/language/gemm_op.py
- tilelang/language/tir/ir.py
- tilelang/language/tir/op.py
    k_iters = T.ceildiv(K, block_K)
    # 4 packed E8M0 per uint32 → load every 4 stages
    sf_load_period = sf_granularity_k * 4 // block_K
These floor divisions break the example outside the current happy path.
sf_load_period silently floors and can become 0, and sfa_num_chunks/sfb_num_chunks truncate tail chunks. Changing block_K, block_M, or block_N away from the current multiples will either divide by zero or skip scale-factor data entirely.
At least make the assumptions explicit
k_iters = T.ceildiv(K, block_K)
# 4 packed E8M0 per uint32 → load every 4 stages
sf_load_period = sf_granularity_k * 4 // block_K
+ assert sf_load_period > 0 and (sf_granularity_k * 4) % block_K == 0, (
+ "block_K must evenly divide 4 * sf_granularity_k"
+ )
@@
- sfa_num_chunks = block_M // 128 # number of 128-element UTCCP chunks
- sfb_num_chunks = block_N // 128
+ assert block_M % 128 == 0 and block_N % 128 == 0, (
+ "This example currently requires block_M and block_N to be multiples of 128"
+ )
+ sfa_num_chunks = block_M // 128 # number of 128-element UTCCP chunks
+ sfb_num_chunks = block_N // 128

Also applies to: 61-64
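The failure mode is easy to reproduce outside TileLang. Below is a hedged sketch of the guarded computation the review asks for; the helper name and the `sf_granularity_k` default are illustrative, not taken from the PR:

```python
def sf_chunk_params(block_M: int, block_N: int, block_K: int,
                    sf_granularity_k: int = 32):
    """Compute scale-factor loading parameters with the assumptions made explicit."""
    # 4 packed E8M0 values per uint32 → one SF load covers 4 * sf_granularity_k of K.
    # Without this guard, floor division can silently produce sf_load_period == 0.
    if (sf_granularity_k * 4) % block_K != 0:
        raise ValueError("block_K must evenly divide 4 * sf_granularity_k")
    sf_load_period = sf_granularity_k * 4 // block_K  # guaranteed >= 1 now

    # Without this guard, floor division silently drops tail UTCCP chunks.
    if block_M % 128 or block_N % 128:
        raise ValueError("block_M and block_N must be multiples of 128")
    sfa_num_chunks = block_M // 128  # number of 128-element UTCCP chunks
    sfb_num_chunks = block_N // 128
    return sf_load_period, sfa_num_chunks, sfb_num_chunks
```

With `block_K = 256` (so `4 * 32 % 256 != 0`) the guard raises instead of letting `sf_load_period` floor to 0 and divide by zero later.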
                                            int a_sf_id, int b_sf_id) {
  ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16";
  ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8";
  ICHECK(scale_in_a == 1 || scale_in_a == -1);
  ICHECK(scale_in_b == 1 || scale_in_b == -1);

  // a_format / b_format for MXF8F6F4: E4M3=0, E5M2=1
  auto encode_mxfp_dtype = [&](DataType dtype) -> uint32_t {
    if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() ||
        dtype.is_float8_e4m3()) {
      return 0u; // E4M3
    } else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) {
      return 1u; // E5M2
    }
    LOG(FATAL) << "Unsupported dtype for block-scaled descriptor: " << dtype;
    return 0u;
  };

  auto set_bits = [](uint32_t value, int start, int width) -> uint32_t {
    uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1);
    return (value & mask) << start;
  };

  uint32_t a_format = encode_mxfp_dtype(ab_dtype);
  uint32_t b_format = a_format;
  uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u;
  uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u;
  uint32_t a_major = a_is_k_major ? 0u : 1u;
  uint32_t b_major = b_is_k_major ? 0u : 1u;
  uint32_t n_dim = static_cast<uint32_t>(atom_n >> 3);
  uint32_t m_dim = static_cast<uint32_t>(atom_m >> 4);

  uint32_t desc = 0;
  desc |= set_bits(0, 0, 2);                               // sparse_id2
  desc |= set_bits(0, 2, 1);                               // sparse_flag
  // bit 3 reserved
  desc |= set_bits(static_cast<uint32_t>(b_sf_id), 4, 2);  // b_sf_id
  // bit 6 reserved
  desc |= set_bits(a_format, 7, 3);                        // a_format
  desc |= set_bits(b_format, 10, 3);                       // b_format
  desc |= set_bits(a_neg, 13, 1);                          // a_negate
  desc |= set_bits(b_neg, 14, 1);                          // b_negate
  desc |= set_bits(a_major, 15, 1);                        // a_major
  desc |= set_bits(b_major, 16, 1);                        // b_major
  desc |= set_bits(n_dim, 17, 6);                          // n_dim
  desc |= set_bits(1, 23, 1);                              // scale_format = 1 (E8M0)
  desc |= set_bits(m_dim, 24, 5);                          // m_dim
  desc |= set_bits(static_cast<uint32_t>(a_sf_id), 29, 2); // a_sf_id
Validate a_sf_id and b_sf_id before packing them.
Lines 254-265 write both SF IDs through a 2-bit field, so invalid values are silently truncated. For example, -1 becomes 3, and 4 aliases back to 0. Because this helper is exposed via tl.get_tcgen5_blockscaled_instr_desc, bad frontend inputs currently compile into the wrong descriptor instead of failing fast.
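The aliasing is easy to demonstrate in isolation. This hedged Python sketch (the helper name is illustrative; the real code uses the C++ `set_bits` lambda and `ICHECK`) shows both the silent truncation and the fail-fast alternative:

```python
def pack_sf_id(desc: int, sf_id: int, start: int, width: int = 2) -> int:
    """Pack an SF ID into a narrow descriptor field, failing fast instead of
    silently truncating out-of-range values."""
    if not 0 <= sf_id < (1 << width):
        raise ValueError(f"sf_id {sf_id} does not fit in {width} bits")
    # Mask is now redundant but kept for symmetry with the C++ set_bits helper.
    return desc | ((sf_id & ((1 << width) - 1)) << start)

# Without the range check, 2-bit masking silently aliases bad inputs:
assert -1 & 0x3 == 3  # -1 truncates to 3
assert 4 & 0x3 == 0   # 4 aliases back to 0
```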
  replacer.register_rule(
      "(tcgen05_name)",
      enable_ws ? "tcgen05mma_blockscaled_ws_ss" : "tcgen05mma_blockscaled_ss");
Don't emit the WS block-scaled helper until it exists.
This branch can generate tl::tcgen05mma_blockscaled_ws_ss, but the new instruction header in this PR only defines tcgen05mma_blockscaled_ss. Any block-scaled kernel with enable_ws=true will fail during CUDA compilation.
Temporary guard until the WS variant is implemented
- replacer.register_rule(
- "(tcgen05_name)",
- enable_ws ? "tcgen05mma_blockscaled_ws_ss" : "tcgen05mma_blockscaled_ss");
+ ICHECK(!enable_ws)
+ << "WS block-scaled tcgen05mma emission is not implemented yet.";
+ replacer.register_rule("(tcgen05_name)", "tcgen05mma_blockscaled_ss");

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
  ICHECK(!enable_ws)
      << "WS block-scaled tcgen05mma emission is not implemented yet.";
  replacer.register_rule("(tcgen05_name)", "tcgen05mma_blockscaled_ss");
TL_DEVICE void tcgen05_sf_warp_transpose(uint32_t *smem_ptr) {
  const uint32_t lane = threadIdx.x % 32;
  uint32_t values[4];
#pragma unroll
  for (uint32_t i = 0; i < 4; ++i)
    values[i] = smem_ptr[(i ^ (lane >> 3)) * 32 + lane];
  __syncwarp();
#pragma unroll
  for (uint32_t i = 0; i < 4; ++i)
    smem_ptr[lane * 4 + (i ^ (lane >> 3))] = values[i];
}
Add a post-store warp sync to the in-place transpose.
Line 78 synchronizes before the writes, which protects the read phase, but nothing orders the shared-memory stores after Line 81 before the next consumer runs. If the elected lane issues UTCCP immediately after this helper returns, it can observe partially written transposed data. Add a second __syncwarp() after the store loop, or make the caller perform one explicitly.
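The required ordering can be sketched with Python threads standing in for the lanes of a warp — a CPU analogy of the two-barrier pattern, not the CUDA helper itself. A barrier before the writes protects the read phase; a second barrier after the writes makes the stores visible before any consumer runs:

```python
import threading

def inplace_transpose(mat, n, tid, barrier):
    # Read phase: snapshot this lane's column into "registers" (locals).
    col = [mat[r][tid] for r in range(n)]
    barrier.wait()          # pre-store sync: all reads done before any write
    for r in range(n):      # write phase: scatter the snapshot as a row
        mat[tid][r] = col[r]
    barrier.wait()          # post-store sync: writes done before any consumer

n = 4
mat = [[r * n + c for c in range(n)] for r in range(n)]
barrier = threading.Barrier(n)
threads = [threading.Thread(target=inplace_transpose, args=(mat, n, t, barrier))
           for t in range(n)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# mat now holds the transpose of the original matrix
```

Dropping the second `barrier.wait()` corresponds exactly to the bug flagged above: a fast "elected lane" could start consuming while other lanes are still mid-store.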
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 5:
    raise ValueError(
        f"Unsupported TCGEN5MMA configuration for block-scaled: M={m_dim}, N={n_dim}, "
        f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
    )
atom_m, atom_n, atom_k, _enable_ws, enable_2cta = (int(x) for x in meta)
# Block-scaled MMA (.block_scale) does NOT support .ws variant per PTX ISA.
# Force non-ws mode. atom_m=128 is required for cta_group::1 non-ws.
enable_ws = 0
Reject WS / 2CTA metadata before forcing the non-WS cta_group::1 form.
Lines 605-608 drop _enable_ws and enable_2cta and still emit the only variant this method knows how to generate. That will mis-lower any shape whose selected TCGEN5 meta requires WS or 2CTA support. Fail fast unless the returned meta already matches the supported non-WS cta_group::1 path.
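A fail-fast unpacking step could look like the following self-contained sketch (function name and plain-tuple `meta` are illustrative stand-ins for the emitter method and its metadata object):

```python
def unpack_blockscaled_meta(meta, m, n, k, a_dtype, accum_dtype):
    """Unpack TCGEN5 MMA metadata, rejecting WS / 2CTA configurations
    that the block-scaled lowering cannot emit."""
    if len(meta) != 5:
        raise ValueError(
            f"Unsupported TCGEN5MMA configuration for block-scaled: "
            f"M={m}, N={n}, K={k}, A dtype={a_dtype}, accum dtype={accum_dtype}")
    atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)
    if enable_ws or enable_2cta:
        # Fail fast instead of silently forcing the non-WS cta_group::1 form.
        raise ValueError(
            f"Block-scaled MMA supports only the non-WS cta_group::1 path, but "
            f"meta selected ws={enable_ws}, 2cta={enable_2cta} for M={m}, N={n}, "
            f"K={k}, A dtype={a_dtype}, accum dtype={accum_dtype}")
    return atom_m, atom_n, atom_k
```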
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 605-605: Unpacked variable atom_k is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)
[warning] 605-605: Unpacked variable enable_2cta is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)
_sf_a = tvm.tir.const(sf_a_id, "int32") if isinstance(sf_a_id, int) else sf_a_id
_sf_b = tvm.tir.const(sf_b_id, "int32") if isinstance(sf_b_id, int) else sf_b_id
runtime_instr_desc = base_instr_desc | (_sf_a << 29) | (_sf_b << 4)
Mask or validate runtime SF IDs before OR-ing them into the descriptor.
Line 722 assumes sf_a_id and sf_b_id are already 2-bit values. If either one is larger than 3, the shift spills into neighboring descriptor fields instead of only updating the SF slots. Apply & 0x3 when packing, and still reject constant IDs outside [0, 3].
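The validate-constants / mask-expressions split can be sketched without TVM; here plain `int` stands in for a compile-time constant and anything else for a runtime TIR expression (helper name illustrative):

```python
def pack_runtime_sf_ids(base_desc, sf_a_id, sf_b_id):
    """Fold SF IDs into the instruction descriptor: reject out-of-range
    constants, and mask non-constant (runtime) values to 2 bits so they
    cannot spill into neighboring descriptor fields."""
    def prep(sf_id):
        if isinstance(sf_id, int):
            if not 0 <= sf_id <= 3:
                raise ValueError(f"SF ID {sf_id} outside allowed range [0, 3]")
            return sf_id
        return sf_id & 0x3  # runtime expression: mask before shifting
    return base_desc | (prep(sf_a_id) << 29) | (prep(sf_b_id) << 4)
```

In the real TVM code the runtime branch would apply `& 0x3` to the `tir` expression, which supports the same bitwise operators.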
if isinstance(tmem_dst, (tir.Buffer,)):
    tmem_ptr = tmem_dst.data
elif isinstance(tmem_dst, BufferLoad):
    tmem_ptr = tmem_dst.buffer.data
elif isinstance(tmem_dst, BufferRegion):
    tmem_ptr = tmem_dst.buffer.data
else:
    tmem_ptr = tmem_dst
return tir.call_intrin("void", tir.op.Op.get("tl.ptx_tcgen05_utccp"), smem_ptr, tmem_ptr, tmem_col_offset)
Don't discard the TMEM slice offset in tcgen05_cp.
For BufferLoad/BufferRegion destinations this helper drops the region indices and forwards only buffer.data. The codegen path then applies only tmem_col_offset, so a sliced destination still writes from TMEM column 0 unless the caller manually duplicates the offset.
Safe short-term fix
- if isinstance(tmem_dst, (tir.Buffer,)):
+ if isinstance(tmem_dst, tir.Buffer):
tmem_ptr = tmem_dst.data
- elif isinstance(tmem_dst, BufferLoad):
- tmem_ptr = tmem_dst.buffer.data
- elif isinstance(tmem_dst, BufferRegion):
- tmem_ptr = tmem_dst.buffer.data
+ elif isinstance(tmem_dst, (BufferLoad, BufferRegion)):
+ raise ValueError(
+ "tcgen05_cp currently expects the base TMEM buffer; "
+ "use tmem_col_offset to select destination columns."
+ )
else:
        tmem_ptr = tmem_dst
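If the offset is to be preserved rather than rejected, the resolution logic could follow this duck-typed sketch. The helper name and the demo classes are hypothetical; the real objects are `tir.BufferRegion` (whose `.region` is a list of ranges with a `.min`) and `tir.BufferLoad` (whose `.indices` holds the load indices):

```python
def tmem_base_and_offset(dst):
    """Resolve a TMEM destination to (base_var, column_offset) without
    discarding the slice offset. `dst` may be region-like (.buffer + .region),
    load-like (.buffer + .indices), or a plain buffer (.data)."""
    region = getattr(dst, "region", None)
    if region is not None:
        # Keep the subregion's starting column instead of hard-coding 0.
        return dst.buffer.data, int(region[-1].min)
    indices = getattr(dst, "indices", None)
    if indices is not None:
        return dst.buffer.data, int(indices[-1])
    return dst.data, 0
```

The returned offset would then be added to `tmem_col_offset` before calling the `tl.ptx_tcgen05_utccp` intrinsic, so a sliced destination lands in the right TMEM column.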
a_dtype = str(A_region.buffer.dtype)
accum_dtype = str(C_region.buffer.dtype)

# Create intrinsic emitter — for block-scaled, TCGEN5 always uses 1 warp group
emitter = TensorCoreIntrinEmitter(
    a_dtype=a_dtype,
    b_dtype=a_dtype,
    accum_dtype=accum_dtype,
Validate A/B dtype compatibility before building the emitter.
This hard-codes b_dtype=a_dtype, so a mixed E4M3/E5M2 call will silently encode B as A's format. The block-scaled descriptor path currently mirrors a single FP8 dtype into both a_format and b_format, so this should be rejected explicitly instead of generating the wrong descriptor.
Minimal defensive fix
a_dtype = str(A_region.buffer.dtype)
+ b_dtype = str(B_region.buffer.dtype)
accum_dtype = str(C_region.buffer.dtype)
+ if b_dtype != a_dtype:
+ raise ValueError(
+ "blockscaled_gemm currently requires A and B to use the same FP8 dtype."
+ )
# Create intrinsic emitter — for block-scaled, TCGEN5 always uses 1 warp group
emitter = TensorCoreIntrinEmitter(
a_dtype=a_dtype,
- b_dtype=a_dtype,
+ b_dtype=b_dtype,
accum_dtype=accum_dtype,🤖 Prompt for AI Agents
def make_blockscaled_gemm_layout(
    C: BufferLikeType,
    A: BufferLikeType,
    transpose_A: bool = False,
) -> "Layout":
🧩 Analysis chain
🏁 Script executed:
# First, let's look at the file structure and locate the relevant lines
fd gemm_op.py

🏁 Script executed:
# Check the imports and the function in question
cat -n tilelang/language/gemm_op.py | head -40

🏁 Script executed:
# Look at the lines around 327-331
sed -n '320,340p' tilelang/language/gemm_op.py

🏁 Script executed:
# Search for all imports related to Layout in the file
rg "import.*Layout|from.*Layout" tilelang/language/gemm_op.py

🏁 Script executed:
# Also check if Layout is imported anywhere or if it's being used as a string
rg "Layout" tilelang/language/gemm_op.py | head -20

🏁 Script executed:
# Find where Layout is defined
fd -t f "\.py$" | xargs rg "^class Layout" | head -5

🏁 Script executed:
# Check what's in tilelang.layout module
fd layout.py

🏁 Script executed:
# Check if there's a TYPE_CHECKING import block in the file
rg "TYPE_CHECKING|tilelang.layout" tilelang/language/gemm_op.py

🏁 Script executed:
# Also check the full imports section to be thorough
sed -n '1,30p' tilelang/language/gemm_op.py

🏁 Script executed:
# Search for the full tilelang.layout import line
rg -B2 -A2 "from tilelang.layout import" tilelang/language/gemm_op.py

🏁 Script executed:
# Check if Layout is imported or if we need to verify the entire imports section
cat -n tilelang/language/gemm_op.py | sed -n '1,100p'

🏁 Script executed:
# Look for TYPE_CHECKING imports more carefully
sed -n '1,500p' tilelang/language/gemm_op.py | grep -B5 -A5 "TYPE_CHECKING\|tilelang.layout"

🏁 Script executed:
# Check around the context where the tilelang.layout import appears
sed -n '180,220p' tilelang/language/gemm_op.py

🏁 Script executed:
# Let's see the function around line 327-331 more completely
sed -n '325,345p' tilelang/language/gemm_op.py

🏁 Script executed:
# Check if there's a pyproject.toml or .flake8 to understand the Ruff config
fd -t f "pyproject.toml|ruff.toml|.flake8"

🏁 Script executed:
# Check the ruff configuration
cat ruff.toml

🏁 Script executed:
# Also check pyproject.toml for any ruff config
cat pyproject.toml | grep -A 50 "tool.ruff"

🏁 Script executed:
# Let's verify if Layout is actually being used elsewhere in the file to understand the context better
rg "Layout" tilelang/language/gemm_op.py

🏁 Script executed:
# Let's check what the function actually returns to understand the full context
sed -n '327,380p' tilelang/language/gemm_op.py

🏁 Script executed:
# Verify if Layout is referenced anywhere else in the function body
sed -n '325,380p' tilelang/language/gemm_op.py | grep -i "layout"

🏁 Script executed:
# Let's check what make_mma_store_layout actually returns
fd -t f "\.py$" -path "*tcgen05*" | head -5

🏁 Script executed:
# Check the TensorCoreIntrinEmitter class to see what make_mma_store_layout returns
rg "def make_mma_store_layout" -A 5

🏁 Script executed:
# Check imports in tcgen05_macro_generator.py
head -50 tilelang/intrinsics/tcgen05_macro_generator.py | grep -E "^import|^from"

🏁 Script executed:
# Check if Layout is imported there
rg "Layout" tilelang/intrinsics/tcgen05_macro_generator.py | head -10

🏁 Script executed:
# Check if TYPE_CHECKING is used anywhere in the file
rg "TYPE_CHECKING" tilelang/language/gemm_op.py

🏁 Script executed:
# Final verification: check the exact Ruff behavior with F821 and future annotations
rg "from __future__ import annotations" tilelang/language/gemm_op.py

🏁 Script executed:
# Let's do a final check to see how Layout is used across the codebase to understand the import pattern
rg "from tilelang.layout import Layout" | head -10
Import Layout or use a TYPE_CHECKING import to avoid the F821 linting error.
The return type annotation references Layout which is not imported in this module. Although the string annotation from from __future__ import annotations prevents runtime errors, Ruff's Pyflakes (F821) check will flag the undefined name. You can fix this by adding from tilelang.layout import Layout at the module level or using a TYPE_CHECKING block if the import is only needed for the annotation.
🧰 Tools
🪛 Ruff (0.15.6)
[error] 331-331: Undefined name Layout
(F821)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/language/gemm_op.py` around lines 327 - 331, The return annotation
"Layout" in make_blockscaled_gemm_layout is not imported and triggers Pyflakes
F821; add an import for Layout (e.g., from tilelang.layout import Layout) at
module level or, if you want to avoid runtime imports, wrap it in a
TYPE_CHECKING block (from typing import TYPE_CHECKING; if TYPE_CHECKING: from
tilelang.layout import Layout) so the symbol exists for linting while preserving
runtime behavior.
Actionable comments posted: 3
♻️ Duplicate comments (3)
tilelang/language/gemm_op.py (2)
330-334: ⚠️ Potential issue | 🟡 Minor

Import Layout for the return annotation.

Line 334 still references Layout without bringing it into the module namespace, so Ruff keeps reporting F821 here. A TYPE_CHECKING import is enough if you want to avoid a runtime import.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/gemm_op.py` around lines 330 - 334, The function make_blockscaled_gemm_layout uses the return annotation "Layout" but Layout is not imported, causing F821; add an import for Layout into this module (preferably inside an if TYPE_CHECKING: block) so the name is available for type checking without a runtime dependency; reference the Layout symbol and the make_blockscaled_gemm_layout function when adding the import.
295-302: ⚠️ Potential issue | 🟠 Major

Reject mixed FP8 operand formats here instead of mirroring A onto B.

b_dtype=a_dtype still silently encodes B as whatever A uses. A mixed E4M3/E5M2 call will build the wrong block-scaled descriptor instead of failing fast.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/gemm_op.py` around lines 295 - 302, Currently the code sets b_dtype implicitly to a_dtype when creating TensorCoreIntrinEmitter; instead, read B's actual dtype (use B_region.buffer.dtype) into b_dtype and add a validation: if either a_dtype or b_dtype is an FP8 format (e.g., "e4m3" or "e5m2") and a_dtype != b_dtype, raise an explicit error (ValueError/Assertion) rejecting mixed FP8 operand formats; keep TensorCoreIntrinEmitter(a_dtype=a_dtype, b_dtype=b_dtype, accum_dtype=accum_dtype) after this check so mixed FP8s fail fast rather than silently mirroring A onto B.

examples/gemm_sm100/gemm_mxfp8_blockscaled.py (1)
36-38: ⚠️ Potential issue | 🟠 Major

Make the example’s tile-size assumptions explicit.

sf_load_period still depends on floor division and can collapse to 0 or skip the last partial pack, while sfa_num_chunks/sfb_num_chunks still drop any tail below 128 elements. As written, this schedule only holds when block_K evenly divides 4 * sf_granularity_k and block_M/block_N are multiples of 128.

Also applies to: 62-65
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 36 - 38, The schedule assumes tile sizes divide certain pack/granularity values but currently uses floor division and drops tails; fix by making assumptions explicit and handling tails: compute sf_load_period = T.ceildiv(sf_granularity_k * 4, block_K) (not floor) and ensure it is at least 1, compute k_iters = T.ceildiv(K, block_K) (already present) and use ceildiv for sfa_num_chunks/sfb_num_chunks (e.g., T.ceildiv(block_M, 128) and T.ceildiv(block_N, 128)) so partial 128-element tails are included, or add explicit assertions that block_K evenly divides 4*sf_granularity_k and block_M/block_N are multiples of 128 (raise helpful error messages referencing sf_load_period, sf_granularity_k, block_K, sfa_num_chunks, sfb_num_chunks, block_M, block_N) if you prefer to keep the simpler schedule.
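The floor-vs-ceiling division distinction this comment hinges on is easy to see in plain Python. This sketch mirrors the behavior of `T.ceildiv` with ordinary ints; the tile size 192 is a hypothetical value chosen to expose the dropped tail.

```python
def ceildiv(a: int, b: int) -> int:
    # Plain-Python equivalent of T.ceildiv for positive ints.
    return -(-a // b)

block_M = 192  # hypothetical tile size not divisible by 128

# Floor division silently drops the 64-element tail...
chunks_floor = block_M // 128        # 1 chunk, 64 elements lost

# ...while ceiling division includes the partial chunk.
chunks_ceil = ceildiv(block_M, 128)  # 2 chunks, tail covered
```

Whether to use ceildiv or to assert exact divisibility is the schedule-design choice the review leaves open; either way the assumption becomes explicit.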
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Around line 142-143: sf_a_id and sf_b_id are computed from sf_load_period but
must be derived from the scale-factor granularity (sf_granularity_k) so SF IDs
stay within 0..3 when a scale factor spans multiple block_K tiles; change the
calculation of sf_a_id and sf_b_id to use the granularity-based index (e.g.,
floor(k * block_K / sf_granularity_k) mod sf_load_period or equivalent) so each
packed E8M0 scale factor is reused across the correct number of K tiles; update
the assignments near where sf_a_id and sf_b_id are set and ensure the logic
aligns with tcgen05 expectations and uses sf_granularity_k instead of
sf_load_period directly.
In `@tilelang/language/gemm_op.py`:
- Around line 240-241: Validate that the scalar-field IDs are within the
documented 0–3 range before packing them into the two-bit fields: add explicit
checks for sf_a_id and sf_b_id (the places where the tcgen05 / block-scaled
descriptor is built) and raise a clear exception (e.g., ValueError) if either is
outside 0..3; apply the same guarded checks at the other occurrence referenced
(the second place handling sf_a_id/sf_b_id around the later block) so invalid
constants fail fast instead of producing an incorrect descriptor.
- Around line 284-293: Validate that A_shape, B_shape and C_shape are consistent
before creating the TensorCoreIntrinEmitter: compute M = int(C_shape[0]), N =
int(C_shape[1]) and K_expected = int(A_shape[-2] if transpose_A else
A_shape[-1]); also compute K_from_B = int(B_shape[-1] if (not transpose_B) else
B_shape[-2]) (or the symmetric form used in this codepath) and verify K_expected
== K_from_B, and that A provides M rows and B provides N columns consistent with
C (e.g. A_shape[0 or -2] matches M depending on transpose_A, and B_shape[1 or
-1] matches N depending on transpose_B); if any check fails, raise an
informative exception (or assert) before instantiating TensorCoreIntrinEmitter
so the emitter is configured only with matching tile dimensions.
---
Duplicate comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Around line 36-38: The schedule assumes tile sizes divide certain
pack/granularity values but currently uses floor division and drops tails; fix
by making assumptions explicit and handling tails: compute sf_load_period =
T.ceildiv(sf_granularity_k * 4, block_K) (not floor) and ensure it is at least
1, compute k_iters = T.ceildiv(K, block_K) (already present) and use ceildiv for
sfa_num_chunks/sfb_num_chunks (e.g., T.ceildiv(block_M, 128) and
T.ceildiv(block_N, 128)) so partial 128-element tails are included, or add
explicit assertions that block_K evenly divides 4*sf_granularity_k and
block_M/block_N are multiples of 128 (raise helpful error messages referencing
sf_load_period, sf_granularity_k, block_K, sfa_num_chunks, sfb_num_chunks,
block_M, block_N) if you prefer to keep the simpler schedule.
In `@tilelang/language/gemm_op.py`:
- Around line 330-334: The function make_blockscaled_gemm_layout uses the return
annotation "Layout" but Layout is not imported, causing F821; add an import for
Layout into this module (preferably inside an if TYPE_CHECKING: block) so the
name is available for type checking without a runtime dependency; reference the
Layout symbol and the make_blockscaled_gemm_layout function when adding the
import.
- Around line 295-302: Currently the code sets b_dtype implicitly to a_dtype
when creating TensorCoreIntrinEmitter; instead, read B's actual dtype (use
B_region.buffer.dtype) into b_dtype and add a validation: if either a_dtype or
b_dtype is an FP8 format (e.g., "e4m3" or "e5m2") and a_dtype != b_dtype, raise
an explicit error (ValueError/Assertion) rejecting mixed FP8 operand formats;
keep TensorCoreIntrinEmitter(a_dtype=a_dtype, b_dtype=b_dtype,
accum_dtype=accum_dtype) after this check so mixed FP8s fail fast rather than
silently mirroring A onto B.
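The fail-fast dtype guard described above can be sketched as a standalone helper. Both the function name and the dtype strings are illustrative assumptions; in the real code the check would sit next to the TensorCoreIntrinEmitter construction and use the buffer dtypes TileLang actually reports.

```python
def check_fp8_operands(a_dtype: str, b_dtype: str) -> None:
    """Reject mixed FP8 operand formats instead of mirroring A onto B.

    Hypothetical standalone version of the suggested guard; the dtype
    spellings below are placeholders for whatever the framework uses.
    """
    fp8_formats = {"e4m3", "e5m2"}
    if (a_dtype in fp8_formats or b_dtype in fp8_formats) and a_dtype != b_dtype:
        raise ValueError(
            f"Mixed FP8 operand formats are not supported: A={a_dtype}, B={b_dtype}"
        )
```

With this in place a mixed E4M3/E5M2 call raises immediately instead of building a wrong block-scaled descriptor.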
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b3f2d616-e7be-4c47-8e70-e3fcfa9cfc4f
📒 Files selected for processing (2)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py
tilelang/language/gemm_op.py
sf_a_id=k % sf_load_period,
sf_b_id=k % sf_load_period,
Derive sf_*_id from scale-factor granularity, not the pack reload period.
The tcgen05 block-scaled path only accepts SF IDs 0-3, but k % sf_load_period only works when one block_K tile maps to one scale factor. If block_K < sf_granularity_k, this starts emitting 0..7, 0..15, etc. instead of reusing each packed E8M0 value for multiple K tiles.
One safe way to keep this parameterized
- sf_load_period = sf_granularity_k * 4 // block_K
+ assert sf_granularity_k % block_K == 0, (
+ "This example currently requires block_K to evenly divide sf_granularity_k"
+ )
+ sf_tiles_per_value = sf_granularity_k // block_K
+ sf_load_period = sf_tiles_per_value * 4
@@
- sf_a_id=k % sf_load_period,
- sf_b_id=k % sf_load_period,
+ sf_a_id=(k // sf_tiles_per_value) % 4,
+ sf_b_id=(k // sf_tiles_per_value) % 4,

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 142 - 143,
sf_a_id and sf_b_id are computed from sf_load_period but must be derived from
the scale-factor granularity (sf_granularity_k) so SF IDs stay within 0..3 when
a scale factor spans multiple block_K tiles; change the calculation of sf_a_id
and sf_b_id to use the granularity-based index (e.g., floor(k * block_K /
sf_granularity_k) mod sf_load_period or equivalent) so each packed E8M0 scale
factor is reused across the correct number of K tiles; update the assignments
near where sf_a_id and sf_b_id are set and ensure the logic aligns with tcgen05
expectations and uses sf_granularity_k instead of sf_load_period directly.
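The granularity-based mapping proposed above can be checked with plain ints. The function below is a hypothetical stand-in for the TileLang loop: `k` is the K-tile index, and the tile parameters are example values, not the kernel's defaults.

```python
def sf_id(k: int, block_K: int, sf_granularity_k: int) -> int:
    """Map the K-tile index k to a scale-factor ID that stays in 0..3.

    Each packed E8M0 scale factor covers sf_granularity_k elements of K,
    so it must be reused for sf_granularity_k // block_K consecutive tiles.
    """
    assert sf_granularity_k % block_K == 0, "block_K must divide sf_granularity_k"
    sf_tiles_per_value = sf_granularity_k // block_K
    return (k // sf_tiles_per_value) % 4
```

With block_K=64 and sf_granularity_k=128, the naive k % sf_load_period (period 8) yields IDs up to 7, while this mapping reuses each ID for two tiles and never leaves 0..3.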
sf_a_id: int = 0,
sf_b_id: int = 0,
Validate constant sf_a_id / sf_b_id before passing them through.
These IDs are documented as 0-3, and the tcgen05 block-scaled descriptor packs them into two-bit fields. A constant value outside that range will silently produce the wrong instruction descriptor here instead of failing fast.
Small guard for the public API
SFA_tmem = legalize(SFA_tmem)
SFB_tmem = legalize(SFB_tmem)
mbar = legalize(mbar) if mbar is not None else None
+ if isinstance(sf_a_id, int) and not 0 <= sf_a_id <= 3:
+ raise ValueError(f"sf_a_id must be in [0, 3], got {sf_a_id}")
+ if isinstance(sf_b_id, int) and not 0 <= sf_b_id <= 3:
+ raise ValueError(f"sf_b_id must be in [0, 3], got {sf_b_id}")

Also applies to: 324-326
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/language/gemm_op.py` around lines 240 - 241, Validate that the
scalar-field IDs are within the documented 0–3 range before packing them into
the two-bit fields: add explicit checks for sf_a_id and sf_b_id (the places
where the tcgen05 / block-scaled descriptor is built) and raise a clear
exception (e.g., ValueError) if either is outside 0..3; apply the same guarded
checks at the other occurrence referenced (the second place handling
sf_a_id/sf_b_id around the later block) so invalid constants fail fast instead
of producing an incorrect descriptor.
🧹 Nitpick comments (5)
src/op/builtin.h (1)
352-365: Minor formatting inconsistency: extra leading space on line 365.

Line 365 has an extra leading space before TVM_DLL compared to lines 355 and 360. This doesn't affect functionality but breaks consistency with surrounding declarations.

-/*!
- * \brief tvm intrinsic for scale factor warp transpose in shared memory.
- */
- TVM_DLL const Op &ptx_tcgen05_sf_warp_transpose();
+/*!
+ * \brief tvm intrinsic for scale factor warp transpose in shared memory.
+ */
+TVM_DLL const Op &ptx_tcgen05_sf_warp_transpose();

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/builtin.h` around lines 352 - 365, Remove the extra leading space before the declaration of the ptx_tcgen05_sf_warp_transpose intrinsic so it aligns with the other TVM_DLL declarations (make the line start with "TVM_DLL const Op &ptx_tcgen05_sf_warp_transpose();"); this fixes the formatting inconsistency in the block containing ptx_tcgen05_mma_blockscaled_ss and ptx_tcgen05_cp.

src/op/builtin.cc (1)
212-226: Minor formatting inconsistencies.

The registrations are functionally correct, but there are minor formatting issues:
- Line 225: Extra space before Integer compared to other registrations.
- Line 226: Trailing whitespace.

 TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_sf_warp_transpose)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
-                                Integer(CallEffectKind::kOpaque));
-
+                               Integer(CallEffectKind::kOpaque));
+

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/builtin.cc` around lines 212 - 226, The three TIR builtin registrations (ptx_tcgen05_mma_blockscaled_ss, ptx_tcgen05_cp, ptx_tcgen05_sf_warp_transpose) have minor formatting inconsistencies: remove the extra space before Integer in the .set_attr call for ptx_tcgen05_sf_warp_transpose so it matches the other registrations and delete the trailing whitespace at the end of that line; ensure the .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); call formatting is consistent across all three builtins.

src/op/gemm_py.cc (1)
90-102: Partial scale-factor specification may cause subtle issues.

The parsing allows specifying only some of the block-scaled fields (e.g., sfaRegion_ without sfbRegion_). Since is_blockscaled in gemm_base.py checks that both regions are present, partial specification silently falls back to non-blockscaled mode rather than raising an error.

Consider either:
- Requiring all four fields together when any is present, or
- Documenting that partial specification is valid and falls back to standard GEMM.

 // Block-scaled GEMM: optional SFA, SFB regions and scale factor IDs
 if (args.size() > 19) {
   node->sfaRegion_ = NormalizeToBufferRegion(args[19]);
 }
 if (args.size() > 20) {
   node->sfbRegion_ = NormalizeToBufferRegion(args[20]);
 }
+// Validate: either both SF regions are set, or neither
+if (node->sfaRegion_.defined() != node->sfbRegion_.defined()) {
+  ICHECK(false) << "Block-scaled GEMM requires both sfaRegion and sfbRegion";
+}
 if (args.size() > 21) {

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/gemm_py.cc` around lines 90 - 102, The parser in gemm_py.cc allows specifying sfaRegion_, sfbRegion_, sfAId_, sfBId_ independently which can produce a partial block-scaled specification that gemm_base.py's is_blockscaled (which expects both regions) will treat as not blockscaled; update the parsing logic so that when any of the four related args (sfaRegion_, sfbRegion_, sfAId_, sfBId_) is present you require all four — e.g., after reading args[19..22] validate presence of the other fields and raise an error or throw a runtime/parse exception if any are missing; locate the assignments to sfaRegion_, sfbRegion_, sfAId_, sfBId_ and the use of NormalizeToBufferRegion to implement this all-or-none validation, referencing is_blockscaled in gemm_base.py as the behavioral expectation.

src/op/gemm_py.h (1)
39-41: Consider consistent naming: sfaRegion_ vs sfAId_.

The naming uses different casing patterns:
- sfaRegion_, sfbRegion_ (all lowercase "sfa"/"sfb")
- sfAId_, sfBId_ (camelCase "sfA"/"sfB")

This propagates to the reflection bindings (sfaRegion vs sfAId). For API consistency, consider aligning to one pattern (e.g., sfARegion_/sfBRegion_ or sfaId_/sfbId_).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/gemm_py.h` around lines 39 - 41, The field names for block-scaled GEMM are inconsistent: you have sfaRegion_ / sfbRegion_ but sfAId_ / sfBId_ (and the reflection bindings use mixed forms like sfaRegion vs sfAId); pick one consistent casing (e.g., sfARegion_ / sfBRegion_ and sfAId_ / sfBId_ or sfaRegion_ / sfbRegion_ and sfaId_ / sfbId_) and rename the symbols accordingly throughout the class and its reflection bindings so declarations, any getters/setters, and reflection names (sfaRegion, sfAId, etc.) all match the chosen pattern; ensure you update all usages and serialization/reflection code that references these identifiers (search for sfaRegion_, sfbRegion_, sfAId_, sfBId_, sfaRegion, sfbRegion, sfAId, sfBId).

tilelang/language/gemm_op.py (1)
320-321: Minor: Remove unused mbar_arg variable.

Line 321 assigns mbar_arg = mbar but line 351 passes mbar directly, making mbar_arg unused.

Cleanup

 assert mbar is not None, "mbar is required for blockscaled_gemm"
-mbar_arg = mbar

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/gemm_op.py` around lines 320 - 321, The assignment "mbar_arg = mbar" is redundant because the code later passes "mbar" directly; remove the unused variable by deleting the "mbar_arg = mbar" line (keeping the assertion assert mbar is not None, "mbar is required for blockscaled_gemm") and verify no other references to mbar_arg exist in the surrounding functions or the blockscaled_gemm call sites.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@src/op/builtin.cc`:
- Around line 212-226: The three TIR builtin registrations
(ptx_tcgen05_mma_blockscaled_ss, ptx_tcgen05_cp, ptx_tcgen05_sf_warp_transpose)
have minor formatting inconsistencies: remove the extra space before Integer in
the .set_attr call for ptx_tcgen05_sf_warp_transpose so it matches the other
registrations and delete the trailing whitespace at the end of that line; ensure
the .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); call formatting is consistent across all
three builtins.
In `@src/op/builtin.h`:
- Around line 352-365: Remove the extra leading space before the declaration of
the ptx_tcgen05_sf_warp_transpose intrinsic so it aligns with the other TVM_DLL
declarations (make the line start with "TVM_DLL const Op
&ptx_tcgen05_sf_warp_transpose();"); this fixes the formatting inconsistency in
the block containing ptx_tcgen05_mma_blockscaled_ss and ptx_tcgen05_cp.
In `@src/op/gemm_py.cc`:
- Around line 90-102: The parser in gemm_py.cc allows specifying sfaRegion_,
sfbRegion_, sfAId_, sfBId_ independently which can produce a partial
block-scaled specification that gemm_base.py's is_blockscaled (which expects
both regions) will treat as not blockscaled; update the parsing logic so that
when any of the four related args (sfaRegion_, sfbRegion_, sfAId_, sfBId_) is
present you require all four — e.g., after reading args[19..22] validate
presence of the other fields and raise an error or throw a runtime/parse
exception if any are missing; locate the assignments to sfaRegion_, sfbRegion_,
sfAId_, sfBId_ and the use of NormalizeToBufferRegion to implement this
all-or-none validation, referencing is_blockscaled in gemm_base.py as the
behavioral expectation.
In `@src/op/gemm_py.h`:
- Around line 39-41: The field names for block-scaled GEMM are inconsistent: you
have sfaRegion_ / sfbRegion_ but sfAId_ / sfBId_ (and the reflection bindings
use mixed forms like sfaRegion vs sfAId); pick one consistent casing (e.g.,
sfARegion_ / sfBRegion_ and sfAId_ / sfBId_ or sfaRegion_ / sfbRegion_ and
sfaId_ / sfbId_) and rename the symbols accordingly throughout the class and its
reflection bindings so declarations (sfaRegion_, sfbRegion_, sfAId_, sfBId_),
any getters/setters, and reflection names (sfaRegion, sfAId, etc.) all match the
chosen pattern; ensure you update all usages and serialization/reflection code
that references these identifiers (search for sfaRegion_, sfbRegion_, sfAId_,
sfBId_, sfaRegion, sfbRegion, sfAId, sfBId).
In `@tilelang/language/gemm_op.py`:
- Around line 320-321: The assignment "mbar_arg = mbar" is redundant because the
code later passes "mbar" directly; remove the unused variable by deleting the
"mbar_arg = mbar" line (keeping the assertion assert mbar is not None, "mbar is
required for blockscaled_gemm") and verify no other references to mbar_arg exist
in the surrounding functions or the blockscaled_gemm call sites.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f8addcf4-750a-4654-8fd8-ab0a3cd86b35
📒 Files selected for processing (14)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py
src/op/builtin.cc
src/op/builtin.h
src/op/gemm_py.cc
src/op/gemm_py.h
src/target/codegen_cuda.cc
tilelang/language/ast/ir.py
tilelang/language/builtin.py
tilelang/language/gemm_op.py
tilelang/language/tir/ir.py
tilelang/language/tir/op.py
tilelang/tileop/gemm/__init__.py
tilelang/tileop/gemm/gemm_base.py
tilelang/tileop/gemm/gemm_tcgen05.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/language/tir/ir.py
Actionable comments posted: 1
♻️ Duplicate comments (2)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py (2)
37-37: ⚠️ Potential issue | 🟠 Major

Derive SF IDs from granularity, not reload-period index.

At Line 128 and Line 129, k % sf_load_period can exceed 3 when one SF value spans multiple K tiles (e.g., block_K < sf_granularity_k). For blockscaled_gemm, SF IDs need to remain in 0..3.

Safer parameterized mapping

- sf_load_period = sf_granularity_k * 4 // block_K
+ assert sf_granularity_k % block_K == 0, (
+     "This kernel currently requires block_K to evenly divide sf_granularity_k"
+ )
+ sf_tiles_per_value = sf_granularity_k // block_K
+ sf_load_period = sf_tiles_per_value * 4
@@
- sf_a_id=k % sf_load_period,
- sf_b_id=k % sf_load_period,
+ sf_a_id=(k // sf_tiles_per_value) % 4,
+ sf_b_id=(k // sf_tiles_per_value) % 4,

Use this read-only check to confirm the API constraint and call-site mismatch:
#!/bin/bash
set -euo pipefail
# Verify blockscaled_gemm API docs/args for SF ID constraints in TileLang
rg -n -C3 'sf_a_id|sf_b_id|Scale factor ID|0-3' tilelang/language/gemm_op.py
# Verify current example call-site derivation
rg -n -C3 'sf_load_period|sf_a_id=|sf_b_id=' examples/gemm_sm100/gemm_mxfp8_blockscaled.py

Also applies to: 119-129
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` at line 37, The current derivation uses sf_load_period and k % sf_load_period which can yield values >3; instead compute SF IDs from granularity: derive the SF tile index as (k // block_K) // sf_granularity_k and then take modulo 4 to get an ID in 0..3, e.g. sf_tile_index = ((k // block_K) // sf_granularity_k) and sf_id = sf_tile_index % 4, and use that for sf_a_id and sf_b_id; update the code that computes sf_load_period and the places assigning sf_a_id/sf_b_id to use this granularity-based mapping (referencing sf_load_period, sf_granularity_k, block_K, sf_a_id, sf_b_id).
35-37: ⚠️ Potential issue | 🟠 Major

Guard the current “happy-path only” assumptions explicitly.

Line 37 can yield sf_load_period == 0, and Lines 61-62 floor-truncate UTCCP chunk counts. That can trigger divide-by-zero at Lines 93/111 or silently skip SF data when block sizes change.

Suggested guardrails

 k_iters = T.ceildiv(K, block_K)
 # 4 packed E8M0 per uint32 → load every 4 stages
-sf_load_period = sf_granularity_k * 4 // block_K
+assert block_K > 0 and sf_granularity_k > 0
+assert (sf_granularity_k * 4) % block_K == 0, (
+    "block_K must evenly divide 4 * sf_granularity_k"
+)
+sf_load_period = sf_granularity_k * 4 // block_K
@@
-sfa_num_chunks = block_M // 128  # number of 128-element UTCCP chunks
-sfb_num_chunks = block_N // 128
+assert block_M % 128 == 0 and block_N % 128 == 0, (
+    "UTCCP path currently requires block_M and block_N to be multiples of 128"
+)
+sfa_num_chunks = block_M // 128
+sfb_num_chunks = block_N // 128

Also applies to: 61-64, 93-94, 111-111
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` around lines 35 - 37, The computation for sf_load_period and subsequent UTCCP chunk counts assumes non-zero divisors and can produce sf_load_period == 0 (or zero chunk counts) leading to divide-by-zero or dropped SF data; update the calculations that produce sf_load_period (currently sf_load_period = sf_granularity_k * 4 // block_K) and any derived UTCCP chunk counts to defensively enforce a minimum of 1 (e.g., wrap with max(1, ...)), and add explicit guards or assertions before using these values as divisors so code using k_iters, sf_load_period, and the UTCCP chunk variables cannot divide by zero or silently skip SF processing. Ensure you modify all uses referenced (sf_load_period, sf_granularity_k, block_K and the chunk count calculations) so they never become zero.
🧹 Nitpick comments (2)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py (1)
235-235: Avoid printing full 8192×8192 tensors in the example path.

Line 235 can dominate runtime/log volume and make benchmark output hard to use. Prefer shape/stats or gated debug printing.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py` at line 235, The example currently prints full tensors c and ref_c (print(f"{c=}, {ref_c=}")), which will log huge 8192×8192 arrays; replace that line to avoid dumping full tensors by printing compact diagnostics instead (e.g., shapes and summary statistics like c.shape, c.dtype, c.mean().item(), c.std().item(), and a max absolute difference between c and ref_c) or wrap the full-tensor print behind a verbose/debug flag so normal runs only emit lightweight summaries; update the print call in gemm_mxfp8_blockscaled.py at the location of variables c and ref_c accordingly.

src/op/copy.cc (1)
1319-1326: Consider applying buffer remapping for dynamic shared memory.

The code uses src->data directly without checking T.buffer_remap. For dynamic shared memory (shared.dyn), the buffer may need remapping similar to other copy paths (e.g., LowerBulkCopy at line 1684 and LowerLDSMCopy at line 1102).

♻️ Proposed fix

+// Apply buffer remapping for dynamic shared memory
+Buffer smem_buf = src;
+if (T.buffer_remap.count(src)) {
+  smem_buf = T.buffer_remap.at(src);
+}
+
 // SMEM base access_ptr
-PrimExpr ptype = tir::TypeAnnotation(src->dtype);
+PrimExpr ptype = tir::TypeAnnotation(smem_buf->dtype);
 auto make_smem_ptr = [&](PrimExpr elem_offset) {
   return Call(DataType::Handle(), builtin::tvm_access_ptr(),
-              {ptype, src->data, elem_offset,
+              {ptype, smem_buf->data, elem_offset,
                IntImm(DataType::Int(32), ELEMENTS_PER_CP),
                IntImm(DataType::Int(32), 1 /*read*/)});
 };

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 1319 - 1326, The shared-memory pointer builder make_smem_ptr currently uses src->data directly, which fails to apply buffer remapping for dynamic shared memory; modify the lambda (and the surrounding logic where it captures src) to check for T.buffer_remap (the same remapping used in LowerBulkCopy/LowerLDSMCopy) and use the remapped buffer handle instead of src->data when buffer_remap is present (i.e., for shared.dyn), so the Call to tvm_access_ptr receives the remapped buffer pointer rather than the original src->data.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/copy.cc`:
- Around line 1341-1344: Guard against an empty stmts vector before accessing
stmts[0]: check if stmts.empty() (derived from total_elements/num_calls) and
either return an appropriate no-op/empty Stmt early or set body to an explicit
empty statement instead of dereferencing stmts[0]; update the block around
variables stmts, body and the SeqStmt construction so the loop only runs when
stmts is non-empty to avoid undefined behavior.
---
Duplicate comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Line 37: The current derivation uses sf_load_period and k % sf_load_period
which can yield values >3; instead compute SF IDs from granularity: derive the
SF tile index as (k // block_K) // sf_granularity_k and then take modulo 4 to
get an ID in 0..3, e.g. sf_tile_index = ((k // block_K) // sf_granularity_k) and
sf_id = sf_tile_index % 4, and use that for sf_a_id and sf_b_id; update the code
that computes sf_load_period and the places assigning sf_a_id/sf_b_id to use
this granularity-based mapping (referencing sf_load_period, sf_granularity_k,
block_K, sf_a_id, sf_b_id).
- Around line 35-37: The computation for sf_load_period and subsequent UTCCP
chunk counts assumes non-zero divisors and can produce sf_load_period == 0 (or
zero chunk counts) leading to divide-by-zero or dropped SF data; update the
calculations that produce sf_load_period (currently sf_load_period =
sf_granularity_k * 4 // block_K) and any derived UTCCP chunk counts to
defensively enforce a minimum of 1 (e.g., wrap with max(1, ...)), and add
explicit guards or assertions before using these values as divisors so code
using k_iters, sf_load_period, and the UTCCP chunk variables cannot divide by
zero or silently skip SF processing. Ensure you modify all uses referenced
(sf_load_period, sf_granularity_k, block_K and the chunk count calculations) so
they never become zero.
---
Nitpick comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled.py`:
- Line 235: The example currently prints full tensors c and ref_c (print(f"{c=},
{ref_c=}")), which will log huge 8192×8192 arrays; replace that line to avoid
dumping full tensors by printing compact diagnostics instead (e.g., shapes and
summary statistics like c.shape, c.dtype, c.mean().item(), c.std().item(), and a
max absolute difference between c and ref_c) or wrap the full-tensor print
behind a verbose/debug flag so normal runs only emit lightweight summaries;
update the print call in gemm_mxfp8_blockscaled.py at the location of variables
c and ref_c accordingly.
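The compact-diagnostics idea above can be sketched in plain Python. The helper below and its float-list input are hypothetical stand-ins for the torch tensors in the example, where one would call .mean(), .std(), and .abs().max() directly.

```python
import math

def summarize(name, values):
    """Return a one-line summary instead of dumping a full tensor."""
    flat = [float(v) for v in values]
    n = len(flat)
    mean = sum(flat) / n
    # Population standard deviation, matching torch.std(unbiased=False).
    std = math.sqrt(sum((v - mean) ** 2 for v in flat) / n)
    max_abs = max(abs(v) for v in flat)
    return f"{name}: n={n} mean={mean:.4f} std={std:.4f} max_abs={max_abs:.4f}"
```

A run would then log lines like `c: n=... mean=... std=... max_abs=...` for both c and ref_c, keeping benchmark output readable while still exposing gross mismatches.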
In `@src/op/copy.cc`:
- Around line 1319-1326: The shared-memory pointer builder make_smem_ptr
currently uses src->data directly, which fails to apply buffer remapping for
dynamic shared memory; modify the lambda (and the surrounding logic where it
captures src) to check for T.buffer_remap (the same remapping used in
LowerBulkCopy/LowerLDSMCopy) and use the remapped buffer handle instead of
src->data when buffer_remap is present (i.e., for shared.dyn), so the Call to
tvm_access_ptr receives the remapped buffer pointer rather than the original
src->data.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 395739bb-5626-4e6f-b16b-88a2c241c98b
📒 Files selected for processing (3)
examples/gemm_sm100/gemm_mxfp8_blockscaled.py
src/op/copy.cc
src/op/copy.h
Stmt body = stmts[0];
for (size_t i = 1; i < stmts.size(); ++i) {
  body = SeqStmt({body, stmts[i]});
}
Guard against empty stmts vector to avoid undefined behavior.
If total_elements is 0 (e.g., a zero-extent region), num_calls will be 0 and stmts will be empty. Accessing stmts[0] on line 1341 would then be undefined behavior.
🛡️ Proposed fix
// Build a sequence of ptx_tcgen05_cp calls
std::vector<Stmt> stmts;
for (int i = 0; i < num_calls; ++i) {
PrimExpr smem_ptr = make_smem_ptr(flat_offset + i * ELEMENTS_PER_CP);
PrimExpr col_offset = IntImm(DataType::Int(32), i * COLS_PER_CP);
stmts.push_back(Evaluate(
Call(DataType::Void(), ptx_tcgen05_cp(),
{smem_ptr, tmem_ptr, col_offset})));
}
+ if (stmts.empty()) {
+ return Evaluate(0); // No-op for zero-element copy
+ }
+
Stmt body = stmts[0];
for (size_t i = 1; i < stmts.size(); ++i) {
body = SeqStmt({body, stmts[i]});
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/op/copy.cc` around lines 1341 - 1344: guard against an empty stmts
vector before accessing stmts[0]. Check stmts.empty() (which follows when
total_elements/num_calls is 0) and either return an appropriate no-op/empty
Stmt early or set body to an explicit empty statement instead of dereferencing
stmts[0]; restructure the block around stmts, body, and the SeqStmt
construction so the loop runs only when stmts is non-empty, avoiding undefined
behavior.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py`:
- Around lines 226-230: Cosine similarity alone is insufficient. After
computing sim, add an elementwise absolute- and relative-error check between c
and ref_c: cast both to a higher-precision dtype (e.g., float32), compute
abs_err = |c - ref_c| and rel_err = abs_err / (|ref_c| + eps), take max_abs_err
and max_rel_err, and assert both stay below chosen thresholds (e.g.,
max_abs_err < 1e-2 and max_rel_err < 1e-2). Keep the existing
cosine_similarity check, and include clear failure messages referencing c,
ref_c, sim, max_abs_err, and max_rel_err so blockscaled_gemm_ref mismatches
are detected even when the angles match.
- Around lines 54-57: The code computes sfa_num_chunks and sfb_num_chunks and
allocates SFA_tmem/SFB_tmem assuming block_M and block_N are exact multiples of
128, yet the transpose only processes block_M // 128 and block_N // 128 chunks.
Add explicit runtime assertions that block_M % 128 == 0 and block_N % 128 == 0
(with clear messages referencing block_M and block_N) before computing
sfa_num_chunks/sfb_num_chunks and allocating SFA_tmem/SFB_tmem, to prevent
silent truncation and corruption.
- Around lines 37-38: The formula sf_load_period = sf_granularity_k * 4 //
block_K can produce scale-factor IDs (sf_a_id, sf_b_id) outside the valid 0-3
range for some block_K values. Rework the logic around sf_load_period and
sf_granularity_k so the IDs are derived from the scale-block structure (e.g.,
the index within a scale-block of size sf_granularity_k, or a bounded modulo)
rather than raw K-tile indices, and add a validation or assertion that sf_a_id
and sf_b_id lie within 0..3 (or clamp/raise a configuration error) before they
are passed to blockscaled_gemm, so the contract always holds.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 018d55e1-e3e6-4926-8b00-08f635ad4831
📒 Files selected for processing (1)
examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py
sf_load_period = sf_granularity_k * 4 // block_K
🧩 Analysis chain
🏁 Script executed:
# First, let's read the file to understand the actual configuration
cat -n examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py | head -150
Repository: tile-ai/tilelang
🏁 Script executed:
# Also find and examine the blockscaled_gemm API definition
fd -t f "gemm_op.py" --exec grep -l "blockscaled_gemm" {} \;
🏁 Script executed:
# And check the gemm_tcgen05 file mentioned
fd -t f "gemm_tcgen05.py"
🏁 Script executed:
# Check for the CUDA template file
fd -t f "tcgen_05.h"
🏁 Script executed:
# Read the blockscaled_gemm API definition
grep -A 50 "def blockscaled_gemm" tilelang/language/gemm_op.py | head -70
🏁 Script executed:
# Read the lowering implementation
cat -n tilelang/tileop/gemm/gemm_tcgen05.py | sed -n '189,250p'
🏁 Script executed:
# Read the CUDA template
cat -n src/tl_templates/cuda/tcgen_05.h | sed -n '60,100p'
🏁 Script executed:
# Read the main() section of the example file to see actual parameter values
cat -n examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py | tail -100
🏁 Script executed:
# Let's verify the mathematical relationship more explicitly
# with the actual parameters from main()
python3 << 'EOF'
# Actual parameters from main()
M, N, K = 8192, 8192, 8192
block_M, block_N, block_K = 128, 256, 128
sf_granularity_k = 128
# Current formula at line 37
sf_load_period_current = sf_granularity_k * 4 // block_K
print(f"Current formula: {sf_granularity_k} * 4 // {block_K} = {sf_load_period_current}")
# Proposed formula
tiles_per_scale = sf_granularity_k // block_K
sf_load_period_proposed = tiles_per_scale * 4
print(f"Proposed formula: ({sf_granularity_k} // {block_K}) * 4 = {sf_load_period_proposed}")
# Number of scale factor blocks
sf_k_blocks = (K + sf_granularity_k - 1) // sf_granularity_k
sf_k_packed = (sf_k_blocks + 3) // 4
print(f"\nScale factor structure:")
print(f" sf_k_blocks (unpacked): {sf_k_blocks}")
print(f" sf_k_packed (uint32): {sf_k_packed}")
print(f" Each uint32 covers: {sf_granularity_k * 4} K elements = {(sf_granularity_k * 4) // block_K} K-tiles")
# K iteration logic
k_iters = (K + block_K - 1) // block_K
print(f"\nK-iteration logic:")
print(f" Total K-tile iterations: {k_iters}")
print(f" Scale factors loaded every {sf_load_period_current} iterations")
print(f" sf_a_id cycles: {list(range(sf_load_period_current))}")
# Scale loading points
load_points = [k for k in range(k_iters) if k % sf_load_period_current == 0]
print(f" Scale factors loaded at k indices: {load_points[:10]}... (showing first 10)")
print(f" Total scale loads: {len(load_points)}")
# Test with different block_K to show the problem
print(f"\n--- Testing with different block_K values ---")
for test_block_K in [64, 128, 256]:
print(f"\nblock_K = {test_block_K}:")
sf_load_period_test = sf_granularity_k * 4 // test_block_K
print(f" Current formula produces sf_load_period = {sf_load_period_test}")
print(f" sf_a_id would cycle through range(0, {sf_load_period_test})")
if sf_load_period_test > 4:
print(f" ⚠️ INVALID: sf_a_id exceeds valid range 0-3!")
if sf_granularity_k % test_block_K == 0:
tiles_per_scale_test = sf_granularity_k // test_block_K
sf_load_period_safe = tiles_per_scale_test * 4
print(f" Proposed formula produces sf_load_period = {sf_load_period_safe}")
print(f" sf_a_id = (k // {tiles_per_scale_test}) % 4 ✓ Always in range 0-3")
else:
print(f" ⚠️ sf_granularity_k ({sf_granularity_k}) not divisible by block_K ({test_block_K})")
EOF
🏁 Script executed:
# Let's also check the blockscaled_gemm lowering to see if it validates sf_a_id/sf_b_id
grep -B 5 -A 15 "sf_a_id\|sf_b_id" tilelang/tileop/gemm/gemm_tcgen05.py | head -50
🏁 Script executed:
# Check if there are any assertions on sf_a_id in the MMA emitter
fd -t f "tcgen05" | xargs grep -l "sf_a_id\|sf_b_id" 2>/dev/null
🏁 Script executed:
# Let's check the actual MMA emitter implementation
grep -A 20 "def tcgen05mma_blockscaled" tilelang/tileop/mma/tcgen05_mma.py 2>/dev/null | head -40
🏁 Script executed:
# Check if tcgen05mma_blockscaled validates sf_a_id
rg "tcgen05mma_blockscaled" --type py -A 30 | head -60
🏁 Script executed:
# Check the CUDA template for any ID validation
rg "sf_a_id|sf_b_id" src/tl_templates/cuda/ -A 2 -B 2
Ensure scale factor IDs remain in valid range (0-3) for all block_K configurations.
The blockscaled_gemm API strictly requires sf_a_id and sf_b_id to be in range 0-3. Line 37's formula sf_load_period = sf_granularity_k * 4 // block_K happens to be safe for the current example (block_K=128 equals sf_granularity_k=128), but will violate this contract for other block_K values. For example, if block_K=64, the formula produces sf_load_period=8, causing lines 114-115 to emit IDs up to 7. Add validation to constrain the configuration, then derive scale IDs from the scale block structure rather than raw K-tile indices.
Suggested fix
- sf_load_period = sf_granularity_k * 4 // block_K
+ if sf_granularity_k % block_K != 0:
+ raise ValueError("sf_granularity_k must be a multiple of block_K")
+ tiles_per_scale = sf_granularity_k // block_K
+ sf_load_period = tiles_per_scale * 4
...
- sf_a_id=k % sf_load_period,
- sf_b_id=k % sf_load_period,
+ sf_a_id=(k // tiles_per_scale) % 4,
+ sf_b_id=(k // tiles_per_scale) % 4,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py` around lines 37 -
38, The calculation for sf_load_period using sf_granularity_k * 4 // block_K can
produce scale-factor IDs (sf_a_id, sf_b_id) outside the valid 0-3 range for
different block_K values; update the logic around sf_load_period,
sf_granularity_k, and the code that computes sf_a_id/sf_b_id so IDs are derived
from the scale-block structure (e.g., index within a scale-block of size
sf_granularity_k or a bounded modulo) rather than raw K-tile indices, and add a
validation/assertion that sf_a_id and sf_b_id are within 0..3 (or clamp/raise a
configuration error) before they are used in blockscaled_gemm so the contract is
always satisfied.
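The suggested derivation can be checked in isolation. The sketch below reimplements the proposed formula outside the kernel (the function name is hypothetical) and asserts the 0-3 contract for a given configuration:

```python
def derive_sf_ids(k_iters, block_K, sf_granularity_k):
    """Scale-factor IDs derived from the scale-block structure (sketch of
    the suggested fix, not the actual kernel code)."""
    if sf_granularity_k % block_K != 0:
        raise ValueError("sf_granularity_k must be a multiple of block_K")
    tiles_per_scale = sf_granularity_k // block_K
    # Each packed uint32 holds 4 scale factors, so the ID cycles 0..3
    # once per scale-block of tiles_per_scale K-tiles.
    ids = [(k // tiles_per_scale) % 4 for k in range(k_iters)]
    assert all(0 <= i <= 3 for i in ids), "sf IDs must stay in 0..3"
    return ids
```

For block_K=128 and sf_granularity_k=128 (the example's configuration) this reproduces the cycle 0,1,2,3 of the current formula; for block_K=64 it repeats each ID for two K-tiles instead of emitting IDs up to 7.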
sfa_num_chunks = block_M // 128
sfb_num_chunks = block_N // 128
SFA_tmem = T.alloc_tmem([32, sfa_num_chunks * 4], "uint32")
SFB_tmem = T.alloc_tmem([32, sfb_num_chunks * 4], "uint32")
🧩 Analysis chain
🏁 Script executed:
find . -name "gemm_mxfp8_blockscaled_1_128_128.py" -type f
🏁 Script executed:
wc -l ./examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py
🏁 Script executed:
cat -n ./examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py
Enforce multiples of 128 for block_M and block_N.
Lines 54–57 allocate TMEM based on floor-divided chunk counts, but the transpose loop (lines 99–102) only processes block_M // 128 and block_N // 128 chunks of 128 elements each. When block_M or block_N are not multiples of 128, the remaining elements are never transposed but are still copied wholesale into undersized TMEM buffers at lines 103–104, leading to data corruption. Add explicit assertions that both must be multiples of 128:
+ if block_M % 128 != 0 or block_N % 128 != 0:
+ raise ValueError("block_M and block_N must be multiples of 128")
sfa_num_chunks = block_M // 128
sfb_num_chunks = block_N // 128
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if block_M % 128 != 0 or block_N % 128 != 0:
    raise ValueError("block_M and block_N must be multiples of 128")
sfa_num_chunks = block_M // 128
sfb_num_chunks = block_N // 128
SFA_tmem = T.alloc_tmem([32, sfa_num_chunks * 4], "uint32")
SFB_tmem = T.alloc_tmem([32, sfb_num_chunks * 4], "uint32")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py` around lines 54 -
57, The code computes sfa_num_chunks and sfb_num_chunks and allocates
SFA_tmem/SFB_tmem assuming block_M and block_N are exact multiples of 128, but
the transpose only processes block_M // 128 and block_N // 128 chunks; add
explicit runtime assertions that block_M % 128 == 0 and block_N % 128 == 0 (with
clear messages referencing block_M and block_N) before computing
sfa_num_chunks/sfb_num_chunks and allocating SFA_tmem/SFB_tmem to prevent silent
truncation/corruption.
ref_c = blockscaled_gemm_ref(a, b, sfa_unpacked[:, :sf_k_blocks], sfb_unpacked_coarse, sf_granularity_k, sf_granularity_n).to(torch.bfloat16)
sim = cosine_similarity(c, ref_c)
print(f"Output shape: {c.shape}, dtype: {c.dtype}")
print(f"Cosine similarity: {sim.item():.6f}")
assert sim > 0.99, f"Cosine similarity too low: {sim.item():.6f}"
Cosine similarity alone is too weak for scale-factor validation.
A uniform exponent or magnitude bug on the block-scaling path keeps the angle almost unchanged, so Lines 227-230 can pass while the output is still numerically wrong. Please add an elementwise absolute/relative-error check alongside the cosine gate.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_sm100/gemm_mxfp8_blockscaled_1_128_128.py` around lines 226 -
230, Cosine similarity is insufficient; add an elementwise absolute and relative
error check between c and ref_c after computing sim: convert c and ref_c to a
higher-precision dtype (e.g., float32), compute absolute errors (abs_err = |c -
ref_c|) and relative errors (rel_err = abs_err / (|ref_c| + eps)), get
max_abs_err and max_rel_err, and assert they are below chosen thresholds (e.g.,
max_abs_err < 1e-2 and max_rel_err < 1e-2); keep the existing cosine_similarity
check and include clear failure messages referencing c, ref_c, sim, max_abs_err,
and max_rel_err so blockscaled_gemm_ref / c mismatches are detected even when
angles match.
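An elementwise gate of that shape can be sketched as follows, using NumPy for portability (the example itself uses torch tensors; the thresholds here are illustrative placeholders, not tuned for MXFP8 accuracy):

```python
import numpy as np

def check_errors(c, ref_c, atol=1e-2, rtol=1e-2, eps=1e-6):
    """Max absolute/relative error gate to complement the cosine check."""
    c64 = np.asarray(c, dtype=np.float64)
    r64 = np.asarray(ref_c, dtype=np.float64)
    abs_err = np.abs(c64 - r64)
    rel_err = abs_err / (np.abs(r64) + eps)  # eps avoids division by zero
    max_abs, max_rel = float(abs_err.max()), float(rel_err.max())
    assert max_abs < atol, f"max_abs_err={max_abs:.4e} exceeds {atol}"
    assert max_rel < rtol, f"max_rel_err={max_rel:.4e} exceeds {rtol}"
    return max_abs, max_rel
```

Unlike cosine similarity, this catches a uniform magnitude error: scaling every element of ref_c by 2 leaves the angle at 0 but trips the relative-error assertion immediately.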
Summary by CodeRabbit
- New Features
- API / Intrinsics
- Language / Exports