Skip to content

Commit 3c5f177

Browse files
authored
Fix CUDA version hardcoding to support CUDA 13+ dynamically (#26518)
## Description

Fixes runtime library loading failures when building with CUDA 13 by replacing hardcoded CUDA 12 references with dynamic version detection. Related to #26516 which updates CUDA 13 build pipelines, but this PR fixes the Python runtime code that was still hardcoded to CUDA 12.

## Problem

The build system correctly detects CUDA 13 via CMake, but the runtime Python code had CUDA 12 hardcoded in multiple locations, causing "CUDA 12 not found" errors on CUDA 13 systems.

## Solution

Modified onnxruntime/__init__.py and setup.py to dynamically use the detected CUDA version instead of hardcoded "12" strings.

## Changes

- Dynamic CUDA version extraction from build info
- Library paths now use f-strings with cuda_major_version
- Added CUDA 13 support to extras_require and dependency exclusions
- Fixed TensorRT RTX package to use correct CUDA version
- Updated version validation to accept CUDA 12+
- Fixed PyTorch compatibility checks to compare versions dynamically

## Impact

- CUDA 13 builds now load correct libraries
- Backward compatible with CUDA 12
- Forward compatible with future CUDA versions

## Testing

Verified with CUDA 13.0 build that library paths resolve correctly and preload_dlls() loads CUDA 13 libraries without errors.
1 parent ea55c16 commit 3c5f177

File tree

2 files changed

+105
-42
lines changed

2 files changed

+105
-42
lines changed

onnxruntime/__init__.py

Lines changed: 87 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
99
"""
1010

11+
import contextlib
12+
1113
__version__ = "1.24.0"
1214
__author__ = "Microsoft"
1315

@@ -133,14 +135,43 @@ def _get_package_root(package_name: str, directory_name: str | None = None):
133135
return None
134136

135137

138+
def _extract_cuda_major_version(version_str: str) -> str:
139+
"""Extract CUDA major version from version string (e.g., '12.1' -> '12').
140+
141+
Args:
142+
version_str: CUDA version string to parse
143+
144+
Returns:
145+
Major version as string, or "12" if parsing fails
146+
"""
147+
return version_str.split(".")[0] if version_str else "12"
148+
149+
150+
def _get_cufft_version(cuda_major: str) -> str:
151+
"""Get cufft library version based on CUDA major version.
152+
153+
Args:
154+
cuda_major: CUDA major version as string (e.g., "12", "13")
155+
156+
Returns:
157+
cufft version as string
158+
"""
159+
# cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12
160+
return "12" if cuda_major == "13" else "11"
161+
162+
136163
def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = True):
164+
# Dynamically determine CUDA major version from build info
165+
cuda_major_version = _extract_cuda_major_version(cuda_version)
166+
cufft_version = _get_cufft_version(cuda_major_version)
167+
137168
if is_windows:
138169
# Path is relative to site-packages directory.
139170
cuda_dll_paths = [
140-
("nvidia", "cublas", "bin", "cublasLt64_12.dll"),
141-
("nvidia", "cublas", "bin", "cublas64_12.dll"),
142-
("nvidia", "cufft", "bin", "cufft64_11.dll"),
143-
("nvidia", "cuda_runtime", "bin", "cudart64_12.dll"),
171+
("nvidia", "cublas", "bin", f"cublasLt64_{cuda_major_version}.dll"),
172+
("nvidia", "cublas", "bin", f"cublas64_{cuda_major_version}.dll"),
173+
("nvidia", "cufft", "bin", f"cufft64_{cufft_version}.dll"),
174+
("nvidia", "cuda_runtime", "bin", f"cudart64_{cuda_major_version}.dll"),
144175
]
145176
cudnn_dll_paths = [
146177
("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"),
@@ -154,12 +185,12 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru
154185
else: # Linux
155186
# cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
156187
cuda_dll_paths = [
157-
("nvidia", "cublas", "lib", "libcublasLt.so.12"),
158-
("nvidia", "cublas", "lib", "libcublas.so.12"),
159-
("nvidia", "cuda_nvrtc", "lib", "libnvrtc.so.12"),
188+
("nvidia", "cublas", "lib", f"libcublasLt.so.{cuda_major_version}"),
189+
("nvidia", "cublas", "lib", f"libcublas.so.{cuda_major_version}"),
190+
("nvidia", "cuda_nvrtc", "lib", f"libnvrtc.so.{cuda_major_version}"),
160191
("nvidia", "curand", "lib", "libcurand.so.10"),
161-
("nvidia", "cufft", "lib", "libcufft.so.11"),
162-
("nvidia", "cuda_runtime", "lib", "libcudart.so.12"),
192+
("nvidia", "cufft", "lib", f"libcufft.so.{cufft_version}"),
193+
("nvidia", "cuda_runtime", "lib", f"libcudart.so.{cuda_major_version}"),
163194
]
164195

165196
# Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux.
@@ -201,15 +232,17 @@ def print_debug_info():
201232

202233
if cuda_version:
203234
# Print version of installed packages that is related to CUDA or cuDNN DLLs.
235+
cuda_major = _extract_cuda_major_version(cuda_version)
236+
204237
packages = [
205238
"torch",
206-
"nvidia-cuda-runtime-cu12",
207-
"nvidia-cudnn-cu12",
208-
"nvidia-cublas-cu12",
209-
"nvidia-cufft-cu12",
210-
"nvidia-curand-cu12",
211-
"nvidia-cuda-nvrtc-cu12",
212-
"nvidia-nvjitlink-cu12",
239+
f"nvidia-cuda-runtime-cu{cuda_major}",
240+
f"nvidia-cudnn-cu{cuda_major}",
241+
f"nvidia-cublas-cu{cuda_major}",
242+
f"nvidia-cufft-cu{cuda_major}",
243+
f"nvidia-curand-cu{cuda_major}",
244+
f"nvidia-cuda-nvrtc-cu{cuda_major}",
245+
f"nvidia-nvjitlink-cu{cuda_major}",
213246
]
214247
for package in packages:
215248
directory_name = "nvidia" if package.startswith("nvidia-") else None
@@ -254,7 +287,7 @@ def is_target_dll(path: str):
254287

255288

256289
def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, directory=None):
257-
"""Preload CUDA 12.x and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
290+
"""Preload CUDA 12.x+ and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
258291
259292
When the installed PyTorch is compatible (using same major version of CUDA and cuDNN),
260293
there is no need to call this function if `import torch` is done before `import onnxruntime`.
@@ -289,30 +322,53 @@ def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, direc
289322
print("Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.")
290323
print("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe.")
291324

292-
if not (cuda_version and cuda_version.startswith("12.")) and (cuda or cudnn):
293-
print(
294-
f"\033[33mWARNING: {package_name} is not built with CUDA 12.x support. "
295-
"Please install a version that supports CUDA 12.x, or call preload_dlls with cuda=False and cudnn=False.\033[0m"
296-
)
297-
return
298-
299-
if not (cuda_version and cuda_version.startswith("12.") and (cuda or cudnn)):
325+
# Check if CUDA version is supported (12.x or 13.x+)
326+
ort_cuda_major = None
327+
if cuda_version:
328+
try:
329+
ort_cuda_major = int(cuda_version.split(".")[0])
330+
if ort_cuda_major < 12 and (cuda or cudnn):
331+
print(
332+
f"\033[33mWARNING: {package_name} is built with CUDA {cuda_version}, which is not supported for preloading. "
333+
f"CUDA 12.x or newer is required. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
334+
)
335+
return
336+
except ValueError:
337+
print(
338+
f"\033[33mWARNING: Unable to parse CUDA version '{cuda_version}'. "
339+
"Skipping DLL preloading. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
340+
)
341+
return
342+
elif cuda or cudnn:
343+
# No CUDA version info available but CUDA/cuDNN preloading requested
300344
return
301345

302346
is_cuda_cudnn_imported_by_torch = False
303347

304348
if is_windows:
305349
torch_version = _get_package_version("torch")
306-
is_torch_for_cuda_12 = torch_version and "+cu12" in torch_version
350+
# Check if torch CUDA version matches onnxruntime CUDA version
351+
torch_cuda_major = None
352+
if torch_version and "+cu" in torch_version:
353+
with contextlib.suppress(ValueError):
354+
# Extract CUDA version from torch (e.g., "2.0.0+cu121" -> 12)
355+
cu_part = torch_version.split("+cu")[1]
356+
torch_cuda_major = int(cu_part[:2]) # First 2 digits are major version
357+
358+
is_torch_cuda_compatible = (
359+
torch_cuda_major == ort_cuda_major if (torch_cuda_major and ort_cuda_major) else False
360+
)
361+
307362
if "torch" in sys.modules:
308-
is_cuda_cudnn_imported_by_torch = is_torch_for_cuda_12
309-
if (torch_version and "+cu" in torch_version) and not is_torch_for_cuda_12:
363+
is_cuda_cudnn_imported_by_torch = is_torch_cuda_compatible
364+
if torch_cuda_major and ort_cuda_major and torch_cuda_major != ort_cuda_major:
310365
print(
311-
f"\033[33mWARNING: The installed PyTorch {torch_version} does not support CUDA 12.x. "
312-
f"Please install PyTorch for CUDA 12.x to be compatible with {package_name}.\033[0m"
366+
f"\033[33mWARNING: The installed PyTorch {torch_version} uses CUDA {torch_cuda_major}.x, "
367+
f"but {package_name} is built with CUDA {ort_cuda_major}.x. "
368+
f"Please install PyTorch for CUDA {ort_cuda_major}.x to be compatible.\033[0m"
313369
)
314370

315-
if is_torch_for_cuda_12 and directory is None:
371+
if is_torch_cuda_compatible and directory is None:
316372
torch_root = _get_package_root("torch", "torch")
317373
if torch_root:
318374
directory = os.path.join(torch_root, "lib")

setup.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,17 @@ def parse_arg_remove_string(argv, arg_name_equal):
5454
wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=")
5555

5656
cuda_version = None
57-
is_cuda_version_12 = False
57+
cuda_major_version = None
5858
rocm_version = None
5959
is_migraphx = False
6060
is_openvino = False
6161
is_qnn = False
6262
qnn_version = None
6363
# The following arguments are mutually exclusive
6464
if wheel_name_suffix == "gpu":
65-
# TODO: how to support multiple CUDA versions?
6665
cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=")
6766
if cuda_version:
68-
is_cuda_version_12 = cuda_version.startswith("12.")
67+
cuda_major_version = cuda_version.split(".")[0]
6968
elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"):
7069
is_migraphx = True
7170
package_name = "onnxruntime-migraphx"
@@ -222,20 +221,26 @@ def run(self):
222221
"libcuda.so.1",
223222
"libcublas.so.11",
224223
"libcublas.so.12",
224+
"libcublas.so.13",
225225
"libcublasLt.so.11",
226226
"libcublasLt.so.12",
227+
"libcublasLt.so.13",
227228
"libcudart.so.11.0",
228229
"libcudart.so.12",
230+
"libcudart.so.13",
229231
"libcudnn.so.8",
230232
"libcudnn.so.9",
231233
"libcufft.so.10",
232234
"libcufft.so.11",
233235
"libcurand.so.10",
234236
"libnvJitLink.so.12",
237+
"libnvJitLink.so.13",
235238
"libnvrtc.so.11.2", # A symlink to libnvrtc.so.11.8.89
236239
"libnvrtc.so.12",
240+
"libnvrtc.so.13",
237241
"libnvrtc-builtins.so.11",
238242
"libnvrtc-builtins.so.12",
243+
"libnvrtc-builtins.so.13",
239244
]
240245

241246
rocm_dependencies = [
@@ -783,8 +788,8 @@ def reformat_run_count(count_str):
783788

784789
# Adding CUDA Runtime as dependency for NV TensorRT RTX python wheel
785790
if package_name == "onnxruntime-trt-rtx":
786-
install_requires.append("nvidia-cuda-runtime-cu12~=12.0")
787-
cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=")
791+
major = cuda_major_version or "12" # Default to CUDA 12
792+
install_requires.append(f"nvidia-cuda-runtime-cu{major}~={major}.0")
788793

789794

790795
def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version, qnn_version):
@@ -823,16 +828,18 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
823828
save_build_and_package_info(package_name, version_number, cuda_version, rocm_version, qnn_version)
824829

825830
extras_require = {}
826-
if package_name == "onnxruntime-gpu" and is_cuda_version_12:
831+
if package_name == "onnxruntime-gpu" and cuda_major_version:
832+
# Determine cufft version: CUDA 13 uses cufft 12, CUDA 12 uses cufft 11
833+
cufft_version = "12.0" if cuda_major_version == "13" else "11.0"
827834
extras_require = {
828835
"cuda": [
829-
"nvidia-cuda-nvrtc-cu12~=12.0",
830-
"nvidia-cuda-runtime-cu12~=12.0",
831-
"nvidia-cufft-cu12~=11.0",
832-
"nvidia-curand-cu12~=10.0",
836+
f"nvidia-cuda-nvrtc-cu{cuda_major_version}~={cuda_major_version}.0",
837+
f"nvidia-cuda-runtime-cu{cuda_major_version}~={cuda_major_version}.0",
838+
f"nvidia-cufft-cu{cuda_major_version}~={cufft_version}",
839+
f"nvidia-curand-cu{cuda_major_version}~=10.0",
833840
],
834841
"cudnn": [
835-
"nvidia-cudnn-cu12~=9.0",
842+
f"nvidia-cudnn-cu{cuda_major_version}~=9.0",
836843
],
837844
}
838845

0 commit comments

Comments
 (0)