     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     LNLinearSigmoid,
-    RMSNorm,
-    RMSNormLinearActivation,
     SemiSparseWeightConfig,
     ToyLinearModel,
-    TransformerBlock,
     clean_caches,
     create_model_and_input,
     generate_results_csv,
@@ -165,61 +162,6 @@ def test_ln_linear_sigmoid(self):
             torch.all((out >= 0) & (out <= 1))
         )  # Check sigmoid output range
 
-    def test_rms_norm(self):
-        # Test RMSNorm
-        rms_norm = RMSNorm(dim=64)
-        x = torch.randn(16, 64)
-        out = rms_norm(x)
-        self.assertEqual(out.shape, (16, 64))
-
-        # Test with different eps
-        rms_norm = RMSNorm(dim=64, eps=1e-5)
-        out = rms_norm(x)
-        self.assertEqual(out.shape, (16, 64))
-
-    def test_rms_norm_linear_activation(self):
-        # Test with default GELU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
-        x = torch.randn(16, 64)
-        out = model(x)
-        self.assertEqual(out.shape, (16, 32))
-        self.assertEqual(out.dtype, torch.float32)
-
-        # Test with ReLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
-        out = model(x)
-        self.assertEqual(out.shape, (16, 32))
-        self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
-
-        # Test with SiLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
-        out = model(x)
-        self.assertEqual(out.shape, (16, 32))
-
-        # Test with invalid activation
-        with self.assertRaises(ValueError):
-            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
-
-    def test_transformer_block(self):
-        # Test with default parameters
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
-        x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
-        out = model(x)
-        self.assertEqual(out.shape, (16, 16, 64))
-        self.assertEqual(out.dtype, torch.float32)
-
-        # Test with different parameters
-        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
-        x = torch.randn(8, 32, 128)
-        out = model(x)
-        self.assertEqual(out.shape, (8, 32, 128))
-
-        # Test with different head dimensions
-        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
-        x = torch.randn(4, 8, 96)
-        out = model(x)
-        self.assertEqual(out.shape, (4, 8, 96))
-
     def test_create_model_and_input(self):
         m, k, n = 16, 64, 32
         model, input_data = create_model_and_input(
@@ -244,63 +186,6 @@ def test_create_model_and_input(self):
         self.assertIsInstance(model, LNLinearSigmoid)
         self.assertEqual(input_data.shape, (m, k))
 
-        # Test RMSNormLinearActivation
-        model, input_data = create_model_and_input(
-            model_type="rms_norm_linear_activation",
-            m=m,
-            k=k,
-            n=n,
-            high_precision_dtype=torch.float32,
-            device="cpu",
-        )
-        self.assertIsInstance(model, RMSNormLinearActivation)
-        self.assertEqual(input_data.shape, (m, k))
-
-        # Test TransformerBlock
-        model, input_data = create_model_and_input(
-            model_type="transformer_block",
-            m=m,
-            k=k,
-            n=n,  # n is not used for transformer_block
-            high_precision_dtype=torch.float32,
-            device="cpu",
-        )
-        self.assertIsInstance(model, TransformerBlock)
-        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
-
-    def test_quantization_on_models(self):
-        # Test quantization on RMSNormLinearActivation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
-        x = torch.randn(16, 64)
-
-        # Test with Int8WeightOnlyConfig
-        config = string_to_config(quantization="int8wo", sparsity=None)
-        if config is not None:
-            # Skip quantization test if torchao.quantization.quantize is not available
-            try:
-                from torchao.quantization import quantize
-                quantized_model = quantize(model, config)
-                out = quantized_model(x)
-                self.assertEqual(out.shape, (16, 32))
-            except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
-
-        # Test quantization on TransformerBlock
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
-        x = torch.randn(16, 16, 64)
-
-        # Test with Int8WeightOnlyConfig
-        config = string_to_config(quantization="int8wo", sparsity=None)
-        if config is not None:
-            # Skip quantization test if torchao.quantization.quantize is not available
-            try:
-                from torchao.quantization import quantize
-                quantized_model = quantize(model, config)
-                out = quantized_model(x)
-                self.assertEqual(out.shape, (16, 16, 64))
-            except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
-
     def test_generate_results_csv(self):
         results = [
             BenchmarkResult(