<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->

# KVBM Kernel Testing Architecture

This document explains how CUDA kernels are tested through PyTorch bindings using a multi-layer FFI architecture.

## What is FFI?

**FFI (Foreign Function Interface)** is a mechanism that allows code written in one programming language to call functions written in another language. In this project, we use FFI in two places:

1. **Rust ↔ CUDA**: Rust calls C/C++ CUDA functions via the C ABI
2. **Python ↔ Rust**: Python calls Rust functions via PyO3 (a Rust-Python bridge)

This enables us to write performance-critical code in CUDA while making it accessible from high-level Python/PyTorch for testing and integration.

## Architecture Layers

```
┌─────────────────────────────────────────────────────────────┐
│ Layer 4: Python Tests (test_tensor_kernels.py)              │
│ • Creates PyTorch tensors as test data                      │
│ • Uses pure PyTorch ops as reference implementation         │
│ • Validates CUDA kernels match PyTorch reference            │
└──────────────────┬──────────────────────────────────────────┘
                   │ import
                   │ from kvbm import kernels
┌──────────────────▼──────────────────────────────────────────┐
│ Layer 3: PyO3 Bindings (src/kernels.rs)                     │
│ • Exposes Rust/CUDA functions to Python                     │
│ • Extracts GPU pointers from PyTorch tensors                │
│ • Validates shapes, dtypes, device placement                │
│ • Functions: block_to_universal(), universal_to_block()     │
└──────────────────┬──────────────────────────────────────────┘
                   │ calls (Rust FFI)
                   │ use kvbm_kernels
┌──────────────────▼──────────────────────────────────────────┐
│ Layer 2: Rust FFI (lib/kvbm-kernels/src/tensor_kernels.rs)  │
│ • Wraps CUDA kernels with safe Rust API                     │
│ • Manages CUDA contexts, streams, memory                    │
│ • Exports: universal_from_block(), block_from_universal()   │
└──────────────────┬──────────────────────────────────────────┘
                   │ extern "C" calls
                   │ unsafe { cuda_function(...) }
┌──────────────────▼──────────────────────────────────────────┐
│ Layer 1: CUDA Kernels (lib/kvbm-kernels/cuda/*.cu)          │
│ • Raw CUDA kernel implementations                           │
│ • Converts between KV cache layouts on GPU                  │
│ • Files: tensor_kernels.cu, vectorized_copy.cu              │
└─────────────────────────────────────────────────────────────┘
```

## Detailed Layer Breakdown

### Layer 1: CUDA Kernels (`lib/kvbm-kernels/cuda/`)

**Purpose**: Implements the actual GPU computation

**Files**:
- `tensor_kernels.cu` - Converts between Stacked (vLLM), Operational (TensorRT-LLM), and Universal (Dynamo) layouts
- `vectorized_copy.cu` - Optimized memory copy operations

**Example**:
```cuda
// CUDA kernel that does the actual work
__global__ void universal_from_block_kernel(
    void** universal_ptrs,
    const void** block_ptrs,
    size_t nb, size_t nh, size_t nl, size_t no, size_t nt, size_t hd
) {
    // GPU code that rearranges memory layouts
}
```

### Layer 2: Rust FFI (`lib/kvbm-kernels/src/`)

**Purpose**: Wraps CUDA kernels with a type-safe Rust API

**Files**:
- `lib.rs` - Module initialization, loads CUDA fatbin files
- `tensor_kernels.rs` - Rust wrappers for CUDA functions

**Example**:
```rust
// Rust function that calls CUDA kernel via FFI
pub unsafe extern "C" fn universal_from_block(
    universal_ptrs: *const *mut c_void,
    block_ptrs: *const *const c_void,
    nb: usize, nh: usize, nl: usize, no: usize, nt: usize, hd: usize,
    dtype: TensorDataType,
    layout: BlockLayout,
    stream: cudaStream_t,
) -> cudaError_t {
    // Call CUDA kernel
    // Return CUDA status code
}
```

**Why Rust?**
- Memory safety without runtime overhead
- Strong type system catches errors at compile time
- Excellent FFI support for calling C/CUDA code

### Layer 3: PyO3 Bindings (`lib/bindings/kvbm/src/kernels.rs`)

**Purpose**: Exposes Rust/CUDA functions to Python

**Example**:
```rust
#[pyfunction]
unsafe fn block_to_universal(
    py: Python<'_>,
    blocks: &Bound<'_, PyAny>,      // PyTorch tensors
    universals: &Bound<'_, PyAny>,  // PyTorch tensors
    layout: &str,
) -> PyResult<()> {
    // 1. Extract GPU pointers from PyTorch tensors
    let ptr: usize = tensor.call_method0("data_ptr")?.extract()?;
    let shape: Vec<usize> = tensor.getattr("shape")?.extract()?;

    // 2. Validate tensor properties
    if !tensor.getattr("is_cuda")?.extract()? {
        return Err(PyValueError::new_err("Tensor must be on CUDA"));
    }

    // 3. Call Rust/CUDA function
    let status = universal_from_block(
        universal_ptrs, block_ptrs, nb, nh, nl, no, nt, hd,
        dtype, layout_enum, stream
    );

    // 4. Handle errors and synchronize
    if status != cudaSuccess {
        return Err(PyRuntimeError::new_err("CUDA error"));
    }
    stream.synchronize()?;
    Ok(())
}
```

**What PyO3 Does**:
- Converts Python objects to Rust types
- Extracts raw GPU memory pointers from PyTorch tensors
- Validates input (shapes, dtypes, device placement)
- Handles errors and converts them to Python exceptions
- Manages CUDA stream synchronization
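
From the Python side, these are just ordinary tensor attributes. As an illustration (the `precheck` helper below is hypothetical, not part of the bindings), a test can mirror the same checks before calling in:

```python
import torch

def precheck(tensors):
    """Mirror the binding's validation from Python (illustrative only)."""
    for t in tensors:
        assert t.is_cuda, "tensor must live on a CUDA device"
        assert t.is_contiguous(), "tensor must be contiguous"
        assert t.dtype == tensors[0].dtype, "all tensors in a batch must share one dtype"
        # The binding reads the raw device pointer and shape through the same attributes:
        _ptr, _shape = t.data_ptr(), tuple(t.shape)
```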

### Layer 4: Python Tests (`lib/bindings/kvbm/tests/`)

**Purpose**: Validates kernel correctness using PyTorch as reference

**Files**:
- `test_tensor_kernels.py` - Comprehensive kernel tests

**Testing Strategy**:
1. Create random PyTorch CUDA tensors
2. Generate **reference output** using pure PyTorch operations (slicing, permuting)
3. Run **CUDA kernel** through PyO3 bindings
4. Compare kernel output vs PyTorch reference with appropriate tolerances

**Example Test**:
```python
import torch
from kvbm import kernels as ctk

def test_block_universal_roundtrip():
    # 1. Create test data
    device = torch.device("cuda:0")
    universals = [torch.randn(3, 2, 2, 4, 5, device=device)]  # [nh, nl, no, nt, hd]

    # 2. Reference: Convert using pure PyTorch
    def _make_blocks(universal, layout="NHD"):
        nh, nl, no, nt, hd = universal.shape
        blocks = []
        for layer in range(nl):
            for outer in range(no):
                slice_ = universal[:, layer, outer, :, :]  # [nh, nt, hd]
                block = slice_.permute(1, 0, 2).contiguous()  # [nt, nh, hd] for NHD, as a contiguous copy
                blocks.append(block)
        return blocks

    blocks = _make_blocks(universals[0], "NHD")

    # 3. Test: Run CUDA kernel
    outputs = [torch.empty_like(universals[0])]
    ctk.block_to_universal(blocks, outputs, "NHD")  # ← Calls CUDA via PyO3!
    torch.cuda.synchronize()

    # 4. Validate: CUDA output should match PyTorch reference
    assert torch.allclose(outputs[0], universals[0], atol=1e-5, rtol=1e-5)
```
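
Since the Test Coverage section below describes this round-trip as lossless, a stricter, bit-exact comparison can also be used when narrowing down a failure (a sketch, with `outputs` and `universals` as in the test above):

```python
# Pure data movement should round-trip bit-exactly, so this is a tighter check than allclose.
assert torch.equal(outputs[0], universals[0])
```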

## Why This Architecture?

### Separation of Concerns
- **CUDA**: Performance-critical GPU code
- **Rust**: Safe systems programming, manages memory/contexts
- **Python**: High-level testing, integration with ML frameworks

### Testing Benefits
1. **Ground Truth**: Pure PyTorch operations are easy to understand and verify
2. **Black Box**: Tests don't need to know CUDA implementation details
3. **Comprehensive**: Can test many configurations (dtypes, layouts, backends)
4. **Debuggable**: When tests fail, can compare intermediate results

### Development Workflow
```bash
# 1. Modify CUDA kernel
vim lib/kvbm-kernels/cuda/tensor_kernels.cu

# 2. Rebuild (compiles CUDA, Rust, and Python bindings)
cd lib/bindings/kvbm
cargo build --release

# 3. Run tests
pytest tests/test_tensor_kernels.py -v

# 4. If tests fail, debug with PyTorch
python -c "import torch; from kvbm import kernels; ..."
```
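
The debug one-liner in step 4 is intentionally elided; an interactive check that exercises the same call path as the tests might look like this (a sketch reusing the shapes from the example test above):

```python
import torch
from kvbm import kernels as ctk

universal = torch.randn(3, 2, 2, 4, 5, device="cuda:0")  # [nh, nl, no, nt, hd]
# One contiguous block per (layer, outer) pair, permuted to the NHD layout [nt, nh, hd].
blocks = [universal[:, layer, outer, :, :].permute(1, 0, 2).contiguous()
          for layer in range(2) for outer in range(2)]

outputs = [torch.empty_like(universal)]
ctk.block_to_universal(blocks, outputs, "NHD")
torch.cuda.synchronize()
print("max abs diff:", (outputs[0] - universal).abs().max().item())
```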

## Running the Tests

### Prerequisites
```bash
# CUDA toolkit (for GPU tests)
# PyTorch with CUDA support
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install kvbm package with dev dependencies
cd lib/bindings/kvbm
pip install -e ".[dev]"
```

### Run All Tests
```bash
cd lib/bindings/kvbm
pytest tests/ -v
```

### Run Specific Test
```bash
pytest tests/test_tensor_kernels.py::test_block_universal_roundtrip -v
```

### Run with Specific Parameters
```bash
# Test only NHD layout with float32
pytest "tests/test_tensor_kernels.py::test_block_universal_roundtrip[NHD-torch.float32]" -v
```
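
Test IDs such as `NHD-torch.float32` come from pytest parametrization. A sketch of how such a matrix can be declared (the actual decorators and ID scheme in `test_tensor_kernels.py` may differ):

```python
import pytest
import torch

@pytest.mark.parametrize(
    "layout,dtype",
    [
        (layout, dtype)
        for layout in ("NHD", "HND")
        for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    ],
    ids=str,  # render readable IDs like "NHD-torch.float32" instead of "dtype0"
)
def test_block_universal_roundtrip(layout, dtype):
    ...  # build tensors of the given dtype, convert, compare against the PyTorch reference
```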

## Test Coverage

### `test_block_universal_roundtrip`
- **Tests**: `block_to_universal()` and `universal_to_block()`
- **Layouts**: NHD (vLLM), HND
- **Dtypes**: float16, bfloat16, float32, float64
- **Validates**: Lossless round-trip conversion

### `test_operational_roundtrip`
- **Tests**: `block_to_operational()` and `operational_to_block()`
- **Validates**: Correct flattening/unflattening of block data

### `test_operational_backends`
- **Tests**: Different memcpy backends (kernel, async, batch, auto)
- **Validates**: All backends produce correct results

### Error Handling Tests
- `test_universal_shape_mismatch` - Rejects incorrect shapes
- `test_dtype_mismatch_error` - Rejects mixed dtypes
- `test_non_cuda_tensor_error` - Rejects CPU tensors
- `test_empty_batch_noop` - Handles empty inputs gracefully
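
For example, the CPU-tensor rejection can be expressed entirely in terms of the public API (a sketch; the actual test and the exact exception type raised by the bindings may differ):

```python
import pytest
import torch
from kvbm import kernels as ctk

def test_non_cuda_tensor_error_sketch():
    # CPU tensors should be rejected by the bindings' validation before any CUDA work runs.
    universal = torch.randn(3, 2, 2, 4, 5)                 # intentionally a CPU tensor
    blocks = [torch.randn(4, 3, 5) for _ in range(2 * 2)]  # one [nt, nh, hd] block per (layer, outer)
    with pytest.raises((ValueError, RuntimeError)):
        ctk.block_to_universal(blocks, [universal], "NHD")
```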

## Debugging Tips

### Enable CUDA Error Checking
```python
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous kernel launches; set before CUDA initializes
import torch
torch.cuda.set_sync_debug_mode(1)  # warn on operations that force a CUDA synchronization
```

### Compare Intermediate Results
```python
# Get reference
blocks_ref = _make_blocks(universal, "NHD")

# Run kernel
outputs = [torch.empty_like(universal)]
ctk.block_to_universal(blocks_ref, outputs, "NHD")

# Compare specific elements
print(f"Max diff: {(outputs[0] - universal).abs().max()}")
print(f"Mean diff: {(outputs[0] - universal).abs().mean()}")
```

### Check Tensor Properties
```python
def inspect_tensor(t, name):
    print(f"{name}:")
    print(f"  shape: {t.shape}")
    print(f"  dtype: {t.dtype}")
    print(f"  device: {t.device}")
    print(f"  is_contiguous: {t.is_contiguous()}")
    print(f"  data_ptr: 0x{t.data_ptr():x}")
```
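
For example, when a conversion misbehaves, dumping both sides helps spot a mismatched layout or device (hypothetical call, using `inspect_tensor` from above and the tensors from the round-trip test):

```python
inspect_tensor(blocks[0], "blocks[0]")
inspect_tensor(universals[0], "universals[0]")
```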

## Common Issues

### Issue: "Tensor must be contiguous"
**Solution**: Call `.contiguous()` before passing to kernel
```python
tensor = tensor.contiguous()
ctk.block_to_universal(blocks, [tensor], "NHD")
```

### Issue: "Mixed dtype error"
**Solution**: Ensure all tensors in a batch have the same dtype
```python
# Bad: mixed dtypes
blocks = [torch.randn(..., dtype=torch.float32), torch.randn(..., dtype=torch.float16)]

# Good: consistent dtype
blocks = [torch.randn(..., dtype=torch.float32) for _ in range(n)]
```

### Issue: "Shape mismatch"
**Solution**: Verify dimensions match expected layout
```python
# For NHD layout, each block should be [nt, nh, hd]
# For universal, should be [nh, nl, no, nt, hd]
```
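
A quick way to check this before calling in (a sketch using the dimension names from the example test above):

```python
nh, nl, no, nt, hd = universal.shape
assert len(blocks) == nl * no, "expected one block per (layer, outer) pair"
assert all(b.shape == (nt, nh, hd) for b in blocks), "NHD blocks must have shape [nt, nh, hd]"
```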

## Further Reading

- [KVBM Kernels README](../../../lib/kvbm-kernels/README.md) - Detailed explanation of kernel layouts
- [PyO3 Documentation](https://pyo3.rs/) - Python-Rust FFI framework
- [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) - CUDA fundamentals