88or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
99"""
1010
11+ import contextlib
12+
1113__version__ = "1.24.0"
1214__author__ = "Microsoft"
1315
@@ -133,14 +135,43 @@ def _get_package_root(package_name: str, directory_name: str | None = None):
133135 return None
134136
135137
138+ def _extract_cuda_major_version (version_str : str ) -> str :
139+ """Extract CUDA major version from version string (e.g., '12.1' -> '12').
140+
141+ Args:
142+ version_str: CUDA version string to parse
143+
144+ Returns:
145+ Major version as string, or "12" if parsing fails
146+ """
147+ return version_str .split ("." )[0 ] if version_str else "12"
148+
149+
150+ def _get_cufft_version (cuda_major : str ) -> str :
151+ """Get cufft library version based on CUDA major version.
152+
153+ Args:
154+ cuda_major: CUDA major version as string (e.g., "12", "13")
155+
156+ Returns:
157+ cufft version as string
158+ """
159+ # cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12
160+ return "12" if cuda_major == "13" else "11"
161+
162+
136163def _get_nvidia_dll_paths (is_windows : bool , cuda : bool = True , cudnn : bool = True ):
164+ # Dynamically determine CUDA major version from build info
165+ cuda_major_version = _extract_cuda_major_version (cuda_version )
166+ cufft_version = _get_cufft_version (cuda_major_version )
167+
137168 if is_windows :
138169 # Path is relative to site-packages directory.
139170 cuda_dll_paths = [
140- ("nvidia" , "cublas" , "bin" , "cublasLt64_12 .dll" ),
141- ("nvidia" , "cublas" , "bin" , "cublas64_12 .dll" ),
142- ("nvidia" , "cufft" , "bin" , "cufft64_11 .dll" ),
143- ("nvidia" , "cuda_runtime" , "bin" , "cudart64_12 .dll" ),
171+ ("nvidia" , "cublas" , "bin" , f"cublasLt64_ { cuda_major_version } .dll" ),
172+ ("nvidia" , "cublas" , "bin" , f"cublas64_ { cuda_major_version } .dll" ),
173+ ("nvidia" , "cufft" , "bin" , f"cufft64_ { cufft_version } .dll" ),
174+ ("nvidia" , "cuda_runtime" , "bin" , f"cudart64_ { cuda_major_version } .dll" ),
144175 ]
145176 cudnn_dll_paths = [
146177 ("nvidia" , "cudnn" , "bin" , "cudnn_engines_runtime_compiled64_9.dll" ),
@@ -154,12 +185,12 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru
154185 else : # Linux
155186 # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
156187 cuda_dll_paths = [
157- ("nvidia" , "cublas" , "lib" , "libcublasLt.so.12 " ),
158- ("nvidia" , "cublas" , "lib" , "libcublas.so.12 " ),
159- ("nvidia" , "cuda_nvrtc" , "lib" , "libnvrtc.so.12 " ),
188+ ("nvidia" , "cublas" , "lib" , f "libcublasLt.so.{ cuda_major_version } " ),
189+ ("nvidia" , "cublas" , "lib" , f "libcublas.so.{ cuda_major_version } " ),
190+ ("nvidia" , "cuda_nvrtc" , "lib" , f "libnvrtc.so.{ cuda_major_version } " ),
160191 ("nvidia" , "curand" , "lib" , "libcurand.so.10" ),
161- ("nvidia" , "cufft" , "lib" , "libcufft.so.11 " ),
162- ("nvidia" , "cuda_runtime" , "lib" , "libcudart.so.12 " ),
192+ ("nvidia" , "cufft" , "lib" , f "libcufft.so.{ cufft_version } " ),
193+ ("nvidia" , "cuda_runtime" , "lib" , f "libcudart.so.{ cuda_major_version } " ),
163194 ]
164195
165196 # Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux.
@@ -201,15 +232,17 @@ def print_debug_info():
201232
202233 if cuda_version :
203234 # Print version of installed packages that is related to CUDA or cuDNN DLLs.
235+ cuda_major = _extract_cuda_major_version (cuda_version )
236+
204237 packages = [
205238 "torch" ,
206- "nvidia-cuda-runtime-cu12 " ,
207- "nvidia-cudnn-cu12 " ,
208- "nvidia-cublas-cu12 " ,
209- "nvidia-cufft-cu12 " ,
210- "nvidia-curand-cu12 " ,
211- "nvidia-cuda-nvrtc-cu12 " ,
212- "nvidia-nvjitlink-cu12 " ,
239+ f "nvidia-cuda-runtime-cu { cuda_major } " ,
240+ f "nvidia-cudnn-cu { cuda_major } " ,
241+ f "nvidia-cublas-cu { cuda_major } " ,
242+ f "nvidia-cufft-cu { cuda_major } " ,
243+ f "nvidia-curand-cu { cuda_major } " ,
244+ f "nvidia-cuda-nvrtc-cu { cuda_major } " ,
245+ f "nvidia-nvjitlink-cu { cuda_major } " ,
213246 ]
214247 for package in packages :
215248 directory_name = "nvidia" if package .startswith ("nvidia-" ) else None
@@ -254,7 +287,7 @@ def is_target_dll(path: str):
254287
255288
256289def preload_dlls (cuda : bool = True , cudnn : bool = True , msvc : bool = True , directory = None ):
257- """Preload CUDA 12.x and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
290+ """Preload CUDA 12.x+ and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
258291
259292 When the installed PyTorch is compatible (using same major version of CUDA and cuDNN),
260293 there is no need to call this function if `import torch` is done before `import onnxruntime`.
@@ -289,30 +322,53 @@ def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, direc
289322 print ("Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure." )
290323 print ("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe." )
291324
292- if not (cuda_version and cuda_version .startswith ("12." )) and (cuda or cudnn ):
293- print (
294- f"\033 [33mWARNING: { package_name } is not built with CUDA 12.x support. "
295- "Please install a version that supports CUDA 12.x, or call preload_dlls with cuda=False and cudnn=False.\033 [0m"
296- )
297- return
298-
299- if not (cuda_version and cuda_version .startswith ("12." ) and (cuda or cudnn )):
325+ # Check if CUDA version is supported (12.x or 13.x+)
326+ ort_cuda_major = None
327+ if cuda_version :
328+ try :
329+ ort_cuda_major = int (cuda_version .split ("." )[0 ])
330+ if ort_cuda_major < 12 and (cuda or cudnn ):
331+ print (
332+ f"\033 [33mWARNING: { package_name } is built with CUDA { cuda_version } , which is not supported for preloading. "
333+ f"CUDA 12.x or newer is required. Call preload_dlls with cuda=False and cudnn=False.\033 [0m"
334+ )
335+ return
336+ except ValueError :
337+ print (
338+ f"\033 [33mWARNING: Unable to parse CUDA version '{ cuda_version } '. "
339+ "Skipping DLL preloading. Call preload_dlls with cuda=False and cudnn=False.\033 [0m"
340+ )
341+ return
342+ elif cuda or cudnn :
343+ # No CUDA version info available but CUDA/cuDNN preloading requested
300344 return
301345
302346 is_cuda_cudnn_imported_by_torch = False
303347
304348 if is_windows :
305349 torch_version = _get_package_version ("torch" )
306- is_torch_for_cuda_12 = torch_version and "+cu12" in torch_version
350+ # Check if torch CUDA version matches onnxruntime CUDA version
351+ torch_cuda_major = None
352+ if torch_version and "+cu" in torch_version :
353+ with contextlib .suppress (ValueError ):
354+ # Extract CUDA version from torch (e.g., "2.0.0+cu121" -> 12)
355+ cu_part = torch_version .split ("+cu" )[1 ]
356+ torch_cuda_major = int (cu_part [:2 ]) # First 2 digits are major version
357+
358+ is_torch_cuda_compatible = (
359+ torch_cuda_major == ort_cuda_major if (torch_cuda_major and ort_cuda_major ) else False
360+ )
361+
307362 if "torch" in sys .modules :
308- is_cuda_cudnn_imported_by_torch = is_torch_for_cuda_12
309- if ( torch_version and "+cu" in torch_version ) and not is_torch_for_cuda_12 :
363+ is_cuda_cudnn_imported_by_torch = is_torch_cuda_compatible
364+ if torch_cuda_major and ort_cuda_major and torch_cuda_major != ort_cuda_major :
310365 print (
311- f"\033 [33mWARNING: The installed PyTorch { torch_version } does not support CUDA 12.x. "
312- f"Please install PyTorch for CUDA 12.x to be compatible with { package_name } .\033 [0m"
366+ f"\033 [33mWARNING: The installed PyTorch { torch_version } uses CUDA { torch_cuda_major } .x, "
367+ f"but { package_name } is built with CUDA { ort_cuda_major } .x. "
368+ f"Please install PyTorch for CUDA { ort_cuda_major } .x to be compatible.\033 [0m"
313369 )
314370
315- if is_torch_for_cuda_12 and directory is None :
371+ if is_torch_cuda_compatible and directory is None :
316372 torch_root = _get_package_root ("torch" , "torch" )
317373 if torch_root :
318374 directory = os .path .join (torch_root , "lib" )
0 commit comments