Attachment: pip_freeze_sageattention_debug.txt (pip freeze output)
Summary
On sm89 (Ada / RTX 4060 Ti) using PyTorch 2.6.0 + CUDA 12.4, the SageAttention FP8 QK kernels behave inconsistently:
- Minimal FP8 kernel tests pass
- Minimal BF16 → FP8 tests pass (unexpected)
- Real workloads (ComfyUI / SDXL / Flux) crash with:
  RuntimeError: CUDA error: unspecified launch failure
- Or corrupted outputs / NaNs
Disabling the FP8 path on sm89 and forcing FP16 kernels completely resolves all instability.
This strongly suggests an issue with the sm89 FP8 QK kernel or its BF16 → FP16 handling, padding, or quantization granularity code paths.
Environment
GPU / CUDA
- GPU: RTX 4060 Ti
- Compute capability: 8.9 (sm89)
- NVIDIA Driver: 550.163.01
- CUDA Runtime: 12.4
- nvcc: 12.4.131
Python / Torch
- Python: 3.13.5
- torch: 2.6.0+cu124
- torch.cuda version: 12.4
- cuDNN: 90100
Triton
- triton: 3.2.0
SageAttention
- Installed wheel: sageattention 2.2.0
- core.py path: .../site-packages/sageattention/core.py
- Arch detection: ['sm89']
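For reference, the values above can be reproduced with standard PyTorch / Triton introspection calls, e.g.:

import torch, triton

print("GPU:", torch.cuda.get_device_name(0))
print("Compute capability:", torch.cuda.get_device_capability(0))  # (8, 9) -> sm89
print("torch:", torch.__version__)
print("torch CUDA runtime:", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("triton:", triton.__version__)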
Reproduction
Minimal FP8 test (FP16 input) → SUCCESS
out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")

Minimal BF16 → FP8 test → SUCCESS (unexpected)
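A self-contained version of these two passing tests is sketched below; the small shape is an assumption, since only the kernel call appears above:

import torch
from sageattention.core import sageattn_qk_int8_pv_fp8_cuda

# Hypothetical small shape for the minimal tests; tensors are (B, H, L, D) for tensor_layout="HND".
B, H, L, D = 1, 8, 256, 64

for dtype in (torch.float16, torch.bfloat16):
    q = torch.randn(B, H, L, D, dtype=dtype, device="cuda")
    k = torch.randn(B, H, L, D, dtype=dtype, device="cuda")
    v = torch.randn(B, H, L, D, dtype=dtype, device="cuda")
    out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")
    torch.cuda.synchronize()  # forces any asynchronous launch failure to surface here
    print(dtype, "finite output:", torch.isfinite(out).all().item())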
Larger shapes or real models → FAILURE
Example:
(B, H, L, D) = (1, 16, 2048, 128)
dtype = torch.bfloat16
→ RuntimeError: CUDA error: unspecified launch failure

Triton kernels on sm89 are also unstable.
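A sketch of the failing case with exactly that shape and dtype (the explicit synchronize only serves to surface the asynchronous error at the call site):

import torch
from sageattention.core import sageattn_qk_int8_pv_fp8_cuda

B, H, L, D = 1, 16, 2048, 128
q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda")

out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")
torch.cuda.synchronize()  # fails here on this setup: CUDA error: unspecified launch failure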
Workaround (100% stable)
In core.py, force sm89 to use FP16 CUDA:
elif arch == "sm89":
return sageattn_qk_int8_pv_fp16_cuda(...)After applying this:
✔️ FP16 works
✔️ BF16 works
✔️ ComfyUI / SDXL stable
❌ FP8 disabled, but no crashes
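For anyone who would rather not patch the installed core.py, the same routing can be approximated from application code by calling the FP16 kernel directly. A minimal sketch, assuming the installed version's keyword defaults are acceptable:

from sageattention.core import sageattn_qk_int8_pv_fp16_cuda

def sm89_safe_attention(q, k, v):
    # Bypass the FP8 dispatch entirely and use the INT8-QK / FP16-PV CUDA kernel,
    # which is the path that proved stable on this GPU. tensor_layout="HND" matches
    # the (B, H, L, D) tensors used in the examples above.
    return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="HND")

Callers (e.g., a ComfyUI attention patch) would then need to be pointed at this wrapper instead of the default dispatcher.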
Hypotheses
- FP8 QK kernel on sm89 is executing the wrong accumulator path (code seems tuned for sm90).
- BF16 → FP16 conversion inside the kernel may be misaligned.
- Padding to 64 / 128 dims may overflow the FP8 MMA path.
- per_warp int8 quantization may not match Ada tensor core requirements (a quick check for this is sketched below).
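One way to probe that last point would be to compare the two quantization granularities against the stable FP16 kernel on a failing shape. A sketch ("per_thread" is assumed to be the other accepted value of qk_quant_gran; check_gran.py is a hypothetical filename):

import sys
import torch
from sageattention.core import sageattn_qk_int8_pv_fp8_cuda, sageattn_qk_int8_pv_fp16_cuda

# Run once per granularity in separate processes, since an "unspecified launch
# failure" usually leaves the CUDA context unusable for further calls:
#   python check_gran.py per_warp
#   python check_gran.py per_thread
gran = sys.argv[1] if len(sys.argv) > 1 else "per_warp"

B, H, L, D = 1, 16, 2048, 128
q, k, v = (torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda") for _ in range(3))

ref = sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="HND")  # stable reference path
out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran=gran)
torch.cuda.synchronize()
print(gran, "max abs diff vs FP16 kernel:", (out - ref).abs().max().item())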