GPU-efficient LLM inference on constrained hardware — custom Triton causal attention kernel with INT8 weight-only quantization. Targets 6 GB VRAM mid-tier GPUs.
Standard PyTorch attention materialises the full T × T score matrix in HBM — an O(T²) memory footprint that constrains batch size and sequence length on 6 GB GPUs. fp16 linear layers add further memory pressure at inference time, where the backward pass is never needed.
Constraints: single GPU (6 GB VRAM) · inference-only (no backward) · fp16 activations · causal self-attention.
Blocked QKᵀ + online softmax + V accumulation in SRAM — no intermediate attention matrix written to HBM.
- Online softmax (numerically stable, single-pass)
- Causal mask applied in-kernel with no separate allocation
- Fused into one kernel dispatch — no intermediate tensors between Q@K, softmax, and @V
- Head dimension fixed at 64; block sizes tunable per hardware
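The blocked scheme above can be sketched as a plain-PyTorch CPU reference (fp32, single head). This is illustrative only; `tiled_causal_attention` is a hypothetical name, not the repo's API, and the real kernel performs these steps per tile in SRAM inside a single Triton dispatch.

```python
import torch

def tiled_causal_attention(q, k, v, block=16):
    # CPU reference for blocked QK^T + online softmax + V accumulation.
    # q, k, v: (T, d) for a single head.
    T, d = q.shape
    sm_scale = d ** -0.5
    out = torch.empty_like(q)
    for qs in range(0, T, block):
        qe = min(qs + block, T)
        qb = q[qs:qe] * sm_scale
        m = torch.full((qe - qs,), float("-inf"))   # running row max
        l = torch.zeros(qe - qs)                    # running softmax normaliser
        acc = torch.zeros(qe - qs, d)               # unnormalised output accumulator
        for ks in range(0, qe, block):              # causal: keys never exceed qe - 1
            ke = min(ks + block, qe)
            s = qb @ k[ks:ke].T
            # in-kernel causal mask: key index must not exceed query index
            qi = torch.arange(qs, qe).unsqueeze(1)
            ki = torch.arange(ks, ke).unsqueeze(0)
            s = s.masked_fill(ki > qi, float("-inf"))
            m_new = torch.maximum(m, s.amax(dim=1))
            alpha = torch.exp(m - m_new)            # correction for stale stats
            p = torch.exp(s - m_new.unsqueeze(1))
            l = l * alpha + p.sum(dim=1)
            acc = acc * alpha.unsqueeze(1) + p @ v[ks:ke]
            m = m_new
        out[qs:qe] = acc / l.unsqueeze(1)
    return out
```

No T × T matrix is ever held: only one `block × block` score tile lives at a time, which is exactly what the Triton kernel keeps in SRAM.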
Weights are quantised to INT8 per output channel at load time and dequantised on the fly during the forward pass.
- ~2× weight memory reduction vs the fp16 baseline (~4× vs fp32)
- fp16 activations preserved — no activation quantisation error
- Per-output-channel scaling: `scale = max(|W|, axis=1) / 127`
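A minimal sketch of this scheme in PyTorch (illustrative function names, not the repo's `QuantLinear` API):

```python
import torch

def quantise_per_channel(w: torch.Tensor):
    # Symmetric per-output-channel INT8: scale = max(|W|, axis=1) / 127,
    # one scale per output row of a (out_features, in_features) weight.
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    scale = scale.clamp_min(1e-8)                   # guard all-zero rows
    qweight = torch.round(w / scale).to(torch.int8)
    return qweight, scale

def dequant_forward(x, qweight, scale):
    # On-the-fly dequant then matmul: the unfused path whose cast shows
    # up as aten::copy_ in the profile.
    return x @ (qweight.float() * scale).T
```

Because the scale is per output channel, the worst-case rounding error on any weight is half a quantisation step of that row.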
Benchmarked on an RTX 4060 (6 GB VRAM, Ada Lovelace), the target deployment hardware; timed with CUDA events over 100 iterations in fp16.
| Batch | Heads | Seq Len | PyTorch (ms) | Triton (ms) | Speedup | VRAM: PyTorch | VRAM: Triton |
|---|---|---|---|---|---|---|---|
| 1 | 4 | 128 | 0.1274 | 0.0301 | 4.23× | 9.18 MB | 8.78 MB |
| 1 | 4 | 256 | 0.1359 | 0.0247 | 5.50× | 10.75 MB | 9.04 MB |
| 1 | 4 | 512 | 0.1347 | 0.0677 | 1.99× | 16.65 MB | 9.57 MB |
| 1 | 8 | 128 | 0.1337 | 0.0193 | 6.92× | 9.83 MB | 9.04 MB |
| 1 | 8 | 256 | 0.2016 | 0.0425 | 4.75× | 12.98 MB | 9.57 MB |
| 1 | 8 | 512 | 0.2130 | 0.1200 | 1.77× | 24.77 MB | 10.62 MB |
| 2 | 4 | 128 | 0.1430 | 0.0236 | 6.05× | 9.83 MB | 9.04 MB |
| 2 | 4 | 256 | 0.1330 | 0.0237 | 5.62× | 12.98 MB | 9.57 MB |
| 2 | 4 | 512 | 0.2062 | 0.0676 | 3.05× | 24.77 MB | 10.62 MB |
| 2 | 8 | 128 | 0.3702 | 0.0215 | 17.23× | 11.14 MB | 9.57 MB |
| 2 | 8 | 256 | 0.1682 | 0.0406 | 4.14× | 17.43 MB | 10.62 MB |
| 2 | 8 | 512 | 0.5204 | 0.1016 | 5.12× | 41.03 MB | 12.71 MB |
Peak speedup: 17.23× at B=2, H=8, seq=128. At B=2, H=8, seq=512: VRAM drops from 41.0 MB → 12.7 MB (3.2× reduction) — the primary benefit at longer sequences on constrained hardware.
The speedups here exceed those on higher-end hardware (RTX 5000 Ada, 32 GB) because memory pressure is more severe on 6 GB — the Triton kernel's SRAM tiling advantage is amplified when HBM bandwidth is the bottleneck.
| Metric | PyTorch | Optimised | Delta |
|---|---|---|---|
| Latency | 0.7453 ms | 0.1275 ms | −82.9% (5.84× faster) |
| Peak VRAM | 16.40 MB | 11.69 MB | −28.7% |
5.84× faster end-to-end with 28.7% less VRAM on the 6 GB RTX 4060. The two gains compound: Triton attention eliminates HBM round-trips, while INT8 weight compression halves the linear layers' memory footprint.
The profiler trace confirms the kernel breakdown:
| Op | CUDA time | Share | Note |
|---|---|---|---|
| attention_kernel | 37.0 µs | 21.1% | single fused Triton dispatch — no materialised T×T matrix |
| aten::bmm | 30.9 µs | 17.6% | PyTorch baseline matmuls |
| aten::copy_ | 31.9 µs | 18.2% | INT8 dequant cast |
| aten::triu | 20.0 µs | 11.4% | causal mask allocation in the PyTorch path |
The Triton attention kernel dispatches once and writes no intermediate buffers to HBM. The PyTorch path allocates and fills a full T×T mask (aten::triu), then runs two separate matmuls and a softmax, each round-tripping through HBM.
| This system | Production equivalent |
|---|---|
| Tiled attention in SRAM | FlashAttention — used in vLLM, HuggingFace, llama.cpp |
| INT8 weight-only quantisation | GPTQ, AWQ weight-only quant for 4/8-bit inference |
| Causal mask in-kernel | Standard in all autoregressive inference engines |
| 6 GB VRAM target | RTX 4060, 3060, laptop GPUs — the edge deployment reality |
The Triton attention kernel is validated on a real transformer end-to-end.
anviit/nanogpt-triton-inference
Integrates this kernel into a nanoGPT implementation and benchmarks autoregressive inference throughput and memory usage against the PyTorch baseline.
✅ Triton attention matches PyTorch rtol=1e-2, atol=1e-2, fp16
✅ QuantLinear output matches fp16 Linear rtol=2e-1, atol=2e-1 (expected for INT8)
```
triton-int8-inference/
├── kernels/
│   ├── __init__.py
│   ├── triton_flash.py            # Triton causal attention kernel
│   └── quant_linear.py            # INT8 weight-only linear layer
├── models/
│   ├── __init__.py
│   └── mini_block.py              # TorchMiniBlock vs OptimizedMiniBlock
├── benchmarks/
│   ├── attention_bench.py         # Attention-only head-to-head → results/attention.json
│   ├── block_bench.py             # End-to-end block comparison
│   ├── baseline.py                # PyTorch attention standalone baseline
│   ├── test_triton_attention.py   # Correctness — Triton vs PyTorch
│   ├── test_quant_linear.py       # Correctness — INT8 vs fp16 Linear
│   └── results/
│       └── attention.json         # Full benchmark results
├── profiler/
│   ├── profile_block.py           # PyTorch profiler + NVTX trace
│   └── mini_block_trace.json      # Chrome trace output
└── pyproject.toml                 # pip install -e . for local imports
```
```
pip install torch triton nvtx
pip install -e .   # makes kernels/ and models/ importable from anywhere in the repo
```

Requires a CUDA-capable NVIDIA GPU (Ampere or newer). Benchmarks were run on an RTX 4060 (6 GB, Ada Lovelace).
```
# Correctness tests
python benchmarks/test_triton_attention.py   # Triton vs PyTorch attention
python benchmarks/test_quant_linear.py       # INT8 vs fp16 linear

# Benchmarks
python benchmarks/attention_bench.py         # attention-only speedup table → results/attention.json
python benchmarks/block_bench.py             # end-to-end block comparison

# Profiling (generates a Chrome trace)
python profiler/profile_block.py             # outputs profiler/mini_block_trace.json
# open the trace with chrome://tracing or perfetto.dev
```

Note: the first run of any Triton kernel is slow (~10–30 s) while Triton JIT-compiles and caches it; subsequent runs are fast.
Why online softmax?
Avoids materialising the full T × T score matrix in HBM. Each tile of scores is computed, the running max and normaliser are updated incrementally (numerically stable), and the output accumulator is corrected — identical to the core algorithm in FlashAttention (Dao et al., 2022). The profiler confirms this: aten::triu (causal mask allocation) appears only in the PyTorch path and consumes 11.4% of CUDA time — entirely absent from the Triton path.
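The merge rule can be checked numerically in a few lines: process one row of scores in two chunks, rescale the first chunk's statistics by `exp(m1 − m2)` when the running max changes, and the full-row softmax max and normaliser are recovered exactly (a toy check, not the kernel code):

```python
import torch

s = torch.randn(100)                       # one row of attention scores
m_full = s.max()
l_full = torch.exp(s - m_full).sum()       # full-row softmax normaliser

a, b = s[:60], s[60:]                      # process the row as two tiles
m1, l1 = a.max(), torch.exp(a - a.max()).sum()
m2 = torch.maximum(m1, b.max())            # updated running max
l2 = l1 * torch.exp(m1 - m2) + torch.exp(b - m2).sum()  # corrected normaliser
```

After the second tile, `m2` equals `m_full` and `l2` matches `l_full` to floating-point precision, which is why a single pass over tiles suffices.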
Why weight-only INT8, not activation quantisation?
Activation quantisation requires per-token dynamic range estimation and is sensitive to outlier activations — a known problem in LLMs. Weight-only quantisation is static, cheap at runtime, and loses only ~0.5–1% accuracy on standard benchmarks while halving weight memory vs fp16 (~4× vs fp32).
Why not fuse dequant+GEMM?
The dequant path (`qweight.float() * scale` → `x.float() @ w.T`) is unfused; `aten::copy_` (18.2% of CUDA time) is the cast overhead. A production fix integrates dequant into the GEMM kernel (GPTQ-triton, Marlin), eliminating this entirely. That is the logical next step.
- FlashAttention: Fast and Memory-Efficient Exact Attention — Dao et al., 2022
- GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers — Frantar et al., 2022
- Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations — Tillet et al., 2019