
Commit d69697a

Merge pull request #112 from HazyResearch/ring-attn
Add Ring Attention Kernel
2 parents 87fa717 + 1506a7a commit d69697a

File tree

11 files changed: +3259 lines, -0 lines


kernels/ring_attention/Makefile

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
torch:
	python3 setup.py build_ext --inplace --verbose
	rm -rf build/

# If there is an error about libc10 not being found, try:
#   python3 -c "import torch; print(torch.__path__[0])"
# Call the printed path PATH, then run:
#   export LD_LIBRARY_PATH=PATH/lib:$LD_LIBRARY_PATH
# For me it was:
#   export LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/torch/lib:$LD_LIBRARY_PATH

clean:
	rm -rf build/ *.so __pycache__/ $(TARGET)

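The libc10 workaround described in the Makefile comments can also be scripted. A small helper (not part of this commit) that prints the suggested export line for the local torch install:

import os
import torch

# libc10.so and the other torch shared libraries ship inside the torch package.
torch_lib = os.path.join(torch.__path__[0], "lib")

# Print the export command from the Makefile comments for this machine.
print(f"export LD_LIBRARY_PATH={torch_lib}:$LD_LIBRARY_PATH")
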
kernels/ring_attention/setup.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import os
import subprocess
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Resolve the ThunderKittens root and the Python / Torch include paths
thunderkittens_root = os.getenv('THUNDERKITTENS_ROOT', os.path.abspath(os.path.join(os.getcwd(), '../../')))
python_include = subprocess.check_output(['python3', '-c', "import sysconfig; print(sysconfig.get_path('include'))"]).decode().strip()
torch_include = subprocess.check_output(['python3', '-c', "import torch; from torch.utils.cpp_extension import include_paths; print(' '.join(['-I' + p for p in include_paths()]))"]).decode().strip()

# CUDA flags
cuda_flags = [
    '-DNDEBUG',
    '-Xcompiler=-Wno-psabi',
    '-Xcompiler=-fno-strict-aliasing',
    '--expt-extended-lambda',
    '--expt-relaxed-constexpr',
    '-forward-unknown-to-host-compiler',
    '--use_fast_math',
    '-std=c++20',
    '-O3',
    '-Xnvlink=--verbose',
    '-Xptxas=--verbose',
    '-Xptxas=--warn-on-spills',
    f'-I{thunderkittens_root}/include',
    f'-I{thunderkittens_root}/prototype',
    f'-I{python_include}',
    '-DTORCH_COMPILE',
    '-DKITTENS_HOPPER',  # assume H100 for ring attn
    '-arch=sm_90a',      # assume H100 for ring attn
] + torch_include.split()
cpp_flags = [
    '-std=c++20',
    '-O3'
]
source_files = ['tk_ring_attention.cu']

setup(
    name='tk_ring_attention',
    ext_modules=[
        CUDAExtension(
            'tk_ring_attention',
            sources=source_files,
            extra_compile_args={'cxx': cpp_flags,
                                'nvcc': cuda_flags},
            libraries=['cuda']
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

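Once the build finishes (via "make torch" or "python3 setup.py build_ext --inplace --verbose"), the compiled extension is importable from this directory. A minimal import check, assuming the build left a tk_ring_attention*.so next to setup.py; it only loads the module and does not call into the kernel:

import torch  # import torch first so libtorch/libc10 are loaded before the extension
import tk_ring_attention

print("loaded:", tk_ring_attention.__file__)
print("exports:", [name for name in dir(tk_ring_attention) if not name.startswith("_")])
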