|
| 1 | +import os |
| 2 | +import subprocess |
| 3 | +from setuptools import setup |
| 4 | +from torch.utils.cpp_extension import BuildExtension, CUDAExtension |
| 5 | + |
| 6 | +# Set environment variables |
| 7 | +thunderkittens_root = os.getenv('THUNDERKITTENS_ROOT', os.path.abspath(os.path.join(os.getcwd(), '../../'))) |
| 8 | +python_include = subprocess.check_output(['python3', '-c', "import sysconfig; print(sysconfig.get_path('include'))"]).decode().strip() |
| 9 | +torch_include = subprocess.check_output(['python3', '-c', "import torch; from torch.utils.cpp_extension import include_paths; print(' '.join(['-I' + p for p in include_paths()]))"]).decode().strip() |
| 10 | + |
| 11 | +# CUDA flags |
| 12 | +cuda_flags = [ |
| 13 | + '-DNDEBUG', |
| 14 | + '-Xcompiler=-Wno-psabi', |
| 15 | + '-Xcompiler=-fno-strict-aliasing', |
| 16 | + '--expt-extended-lambda', |
| 17 | + '--expt-relaxed-constexpr', |
| 18 | + '-forward-unknown-to-host-compiler', |
| 19 | + '--use_fast_math', |
| 20 | + '-std=c++20', |
| 21 | + '-O3', |
| 22 | + '-Xnvlink=--verbose', |
| 23 | + '-Xptxas=--verbose', |
| 24 | + '-Xptxas=--warn-on-spills', |
| 25 | + f'-I{thunderkittens_root}/include', |
| 26 | + f'-I{thunderkittens_root}/prototype', |
| 27 | + f'-I{python_include}', |
| 28 | + '-DTORCH_COMPILE', |
| 29 | + '-DKITTENS_HOPPER', # assume H100 for ring attn |
| 30 | + '-arch=sm_90a', # assume H100 for ring attn |
| 31 | +] + torch_include.split() |
| 32 | +cpp_flags = [ |
| 33 | + '-std=c++20', |
| 34 | + '-O3' |
| 35 | +] |
| 36 | +source_files = ['tk_ring_attention.cu'] |
| 37 | + |
| 38 | +setup( |
| 39 | + name='tk_ring_attention', |
| 40 | + ext_modules=[ |
| 41 | + CUDAExtension( |
| 42 | + 'tk_ring_attention', |
| 43 | + sources=source_files, |
| 44 | + extra_compile_args={'cxx' : cpp_flags, |
| 45 | + 'nvcc' : cuda_flags}, |
| 46 | + libraries=['cuda'] |
| 47 | + ) |
| 48 | + ], |
| 49 | + cmdclass={ |
| 50 | + 'build_ext': BuildExtension |
| 51 | + } |
| 52 | +) |
0 commit comments