 //
 // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
 // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+// - Added proper architecture check at both host and device level
 //
 
 
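The note above mentions an architecture check at both the host and the device level, but the hunks in this diff only show the host-side TORCH_CHECK guards. For context, a device-level gate in CUDA is typically a compile-time __CUDA_ARCH__ check inside the kernel itself; the sketch below is a hypothetical illustration of that pattern, not the actual torchao kernel body.

    // Hypothetical device-level gate; illustrative only.
    __global__ void example_kernel(float* out) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
        // The real work is only compiled for SM75 (Turing) and newer targets.
        out[threadIdx.x] = 1.0f;
    #else
        // Older targets get an empty stub so the binary still links; the
        // host-side TORCH_CHECK added below keeps this path from launching.
        (void)out;
    #endif
    }

Because __CUDA_ARCH__ is resolved at compile time per target architecture, the gate costs nothing at runtime; the host-side check added below handles the remaining case of running on a device older than any compiled target.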
@@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream,
     static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be 'half' or '__nv_bfloat16'");
     assert(M_Global % 256 == 0);
     assert(K_Global % 64 == 0);
-    assert(N_Global>0);
+    assert(N_Global > 0);
+
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }
 
     // Work around to support more N shapes:
     size_t N_PowerOf2;
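As a standalone illustration of the host-side gate added above, the following sketch queries the same attributes with the plain CUDA runtime API; the program itself is hypothetical, but cudaGetDevice, cudaDeviceGetAttribute, and the attribute enums are the real API used in the hunk.

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        int device = 0, major = 0, minor = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
        cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
        // Same condition as the kernel's early exit: reject anything below SM75.
        const bool supported = (major > 7) || (major == 7 && minor >= 5);
        printf("SM%d%d -> %s\n", major, minor, supported ? "supported" : "unsupported");
        return 0;
    }

Note the two-stage gating: the early exit rejects pre-Turing devices outright, while the separate is_sm75_gpu flag only restricts Turing to half inputs, since SM75 tensor cores do not support bfloat16.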
@@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream,
     if (N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
     if (N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;
 
-    // Check GPU Compute Capability
-    int device, major, minor;
-    CHECK_CUDA(cudaGetDevice(&device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-    const bool is_sm75_gpu = (major == 7) && (minor == 5);
-    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
-        TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75");
-    if ((major < 7) || (major == 7 && minor < 5))
-        TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");
-
     if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
         // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
         if (Split_K == 1) {
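The N_PowerOf2 workaround above pads the N dimension to a tile-friendly size before dispatch. The following is a self-contained sketch of the two rounding rules visible in this hunk (the function name round_n is ours, and the branches for N <= 64 live outside this diff, so they are represented only by the pass-through default):

    // Sketch of the visible rounding rules; the N <= 64 cases are not shown here.
    size_t round_n(size_t n_global) {
        size_t n_pow2 = n_global;  // placeholder for the N <= 64 branches
        if (n_global > 64 && n_global <= 128) n_pow2 = 128;
        if (n_global > 128) n_pow2 = ((n_global - 1) / 128 + 1) * 128;  // ceil to a multiple of 128
        return n_pow2;
    }
    // Examples: round_n(100) == 128, round_n(129) == 256, round_n(384) == 384.

In other words, any N above 128 is rounded up to the next multiple of 128, which is why the dispatch code treats N_PowerOf2 % 128 == 0 as the general case.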
@@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             default:  if (N_PowerOf2 % 128 != 0) {
-                          TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                          TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
                       }
                       Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             default:  if (N_PowerOf2 % 128 != 0) {
-                          TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                          TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
                       }
                       Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
         }
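The two dispatch blocks above differ only on the output side: with Split_K == 1 the kernel writes the final result directly into C using the input type, while for Split_K > 1 each of the Split_K slices of the K dimension writes a float partial product into Reduction_Workspace, to be summed by a reduction pass afterwards. The sketch below illustrates that reduction idea; the kernel name, the slice-major workspace layout, and the half output type are assumptions, not the actual torchao reduction code.

    #include <cuda_fp16.h>

    // Hypothetical split-K reduction: workspace holds split_k partial results
    // of mn = M_Global * N_Global floats each, stored slice after slice.
    __global__ void reduce_split_k(half* C, const float* workspace,
                                   size_t mn, int split_k) {
        size_t idx = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
        if (idx >= mn) return;
        float acc = 0.0f;
        for (int k = 0; k < split_k; ++k)  // sum the partial products across slices
            acc += workspace[k * mn + idx];
        C[idx] = __float2half(acc);        // downcast once, at the very end
    }

Accumulating in float and downcasting once at the end is the point of the float workspace: it avoids compounding rounding error across the Split_K partial sums.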
@@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
     torch::Tensor _scales,
     int64_t splitK=1)
 {
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }
+
     const int64_t NBITS = 1 + EXPONENT + MANTISSA;
     int num_in_feats = _in_feats.size(0);
     int num_in_channels = _in_feats.size(1);
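For reference, the NBITS line spells out the floating-point layout handled by this entry point: one sign bit plus the EXPONENT and MANTISSA field widths. The FP6 format used by FP6-LLM is E3M2, so NBITS = 1 + 3 + 2 = 6; other eXmY layouts, such as a 5-bit E2M2, go through the same template. A compile-time sketch of the same arithmetic (the nbits variable template is illustrative only):

    // Hypothetical compile-time mirror of the NBITS computation above.
    template <int EXPONENT, int MANTISSA>
    constexpr int nbits = 1 + EXPONENT + MANTISSA;  // sign + exponent + mantissa

    static_assert(nbits<3, 2> == 6, "FP6 (E3M2) packs into 6 bits");
    static_assert(nbits<2, 2> == 5, "FP5 (E2M2) packs into 5 bits");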