
FP8 QK kernels unstable on Ada (sm89, RTX 4060 Ti) — BF16/FP8 path causes launch failures, FP16 fallback fully resolves #309

@NasBuk

Description


Full pip freeze attached: pip_freeze_sageattention_debug.txt

Summary

On sm89 (Ada / RTX 4060 Ti) using PyTorch 2.6.0 + CUDA 12.4, the SageAttention FP8 QK kernels behave inconsistently:

  • Minimal FP8 kernel tests pass

  • Minimal BF16 → FP8 tests pass (unexpected)

  • Real workloads (ComfyUI / SDXL / Flux) crash with:

    • RuntimeError: CUDA error: unspecified launch failure
    • Or corrupted outputs / NaNs

Disabling the FP8 path on sm89 and forcing FP16 kernels completely resolves all instability.

This strongly suggests an issue in the sm89 FP8 QK kernel itself, or in its BF16 → FP16 conversion, padding, or quantization-granularity code paths.


Environment

GPU / CUDA

  • GPU: RTX 4060 Ti
  • Compute capability: 8.9 (sm89)
  • NVIDIA Driver: 550.163.01
  • CUDA Runtime: 12.4
  • nvcc: 12.4.131

Python / Torch

  • Python: 3.13.5
  • torch: 2.6.0+cu124
  • torch.cuda version: 12.4
  • cuDNN: 90100

Triton

  • triton: 3.2.0

SageAttention

  • Installed wheel: sageattention 2.2.0
  • core.py path: .../site-packages/sageattention/core.py

SageAttention arch detection

['sm89']
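
For cross-checking, the compute capability can also be read directly from PyTorch with a standard call, independent of SageAttention's own detection:

import torch
print("sm%d%d" % torch.cuda.get_device_capability(0))  # prints sm89 on an RTX 4060 Ti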

Reproduction

Minimal FP8 test (FP16 input) → SUCCESS

out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")
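
For context, a self-contained version of this minimal test looks roughly like the following. The (B, H, L, D) shape is my own illustrative choice, and I'm assuming the kernel is importable from the package root as it is in core.py; only the call itself matches the test above.

# Hedged sketch of the passing minimal FP8 test (FP16 inputs).
import torch
from sageattention import sageattn_qk_int8_pv_fp8_cuda

B, H, L, D = 1, 8, 1024, 64
q = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")

out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")
torch.cuda.synchronize()  # force the async launch to complete so errors surface here
print(out.shape, bool(out.isnan().any()))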

Minimal BF16 → FP8 test → SUCCESS (unexpected)

Larger shapes or real models → FAILURE

Example:

(B, H, L, D) = (1, 16, 2048, 128)
dtype = torch.bfloat16

RuntimeError: CUDA error: unspecified launch failure
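
A sketch of the failing configuration, assuming the same call site as the minimal test above (only the shape and dtype come from the report):

# Sketch of the failing BF16 case at the reported shape; the call site mirrors
# the minimal test and is an assumption, only (B, H, L, D) and dtype are reported.
import torch
from sageattention import sageattn_qk_int8_pv_fp8_cuda

B, H, L, D = 1, 16, 2048, 128
q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device="cuda")

out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")
torch.cuda.synchronize()  # on sm89 this is where "unspecified launch failure" is raised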

The Triton kernels are also unstable on sm89.


Workaround (100% stable)

In core.py, force sm89 to use the FP16 CUDA kernel:

elif arch == "sm89":
    return sageattn_qk_int8_pv_fp16_cuda(...)
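
For completeness, a rough sketch of what the patched branch might look like; the surrounding dispatch signature and keyword names are assumptions about the sageattn() wrapper, not the exact source.

# Illustrative patched dispatch; argument names besides q, k, v and
# tensor_layout are assumptions about the surrounding wrapper in core.py.
elif arch == "sm89":
    # Skip the FP8 QK path on Ada entirely and use the FP16 PV kernel instead.
    return sageattn_qk_int8_pv_fp16_cuda(
        q, k, v,
        tensor_layout=tensor_layout,
        is_causal=is_causal,
        qk_quant_gran="per_warp",
        sm_scale=sm_scale,
    )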

After applying this:

✔️ FP16 works
✔️ BF16 works
✔️ ComfyUI / SDXL stable
❌ FP8 disabled, but no crashes


Hypotheses

  • The FP8 QK kernel on sm89 may be executing the wrong accumulator path (the code appears tuned for sm90).
  • BF16 → FP16 conversion inside the kernel may be misaligned.
  • Padding to 64 / 128 dims may overflow the FP8 MMA path.
  • per_warp int8 quantization may not match Ada tensor core requirements.
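
To help narrow these down, a hypothetical probe that sweeps dtype, sequence length, and head dim could show which factor first trips the kernel. This is not from my original debugging; note that a real launch failure usually poisons the CUDA context, so in practice each configuration would need its own process.

# Hypothetical probe: sweep dtype / L / D against the FP8 kernel to see which
# combination first produces NaNs or a launch failure.
# Caveat: a launch failure corrupts the CUDA context, so run one config per process.
import torch
from sageattention import sageattn_qk_int8_pv_fp8_cuda

def probe(B, H, L, D, dtype):
    q, k, v = (torch.randn(B, H, L, D, dtype=dtype, device="cuda") for _ in range(3))
    try:
        out = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", qk_quant_gran="per_warp")
        torch.cuda.synchronize()
        return "nan" if out.isnan().any() else "ok"
    except RuntimeError as e:
        return f"fail: {e}"

for dtype in (torch.float16, torch.bfloat16):
    for L in (128, 512, 2048):
        for D in (64, 128):
            print(dtype, L, D, probe(1, 16, L, D, dtype))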
