@@ -17,7 +17,10 @@
     PackedLinearInt8DynamicActivationIntxWeightLayout,
 )
 from torchao.experimental.q_dq_layout import QDQLayout
-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
+from torchao.experimental.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    replace_q_dq_with_torchao_quantized_linear_ops,
+)
 from torchao.quantization.granularity import PerGroup, PerRow
 from torchao.quantization.quant_api import quantize_
 from torchao.utils import unwrap_tensor_subclass
@@ -79,7 +82,7 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
         quantized_model = copy.deepcopy(model)
         quantize_(
             quantized_model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -91,7 +94,7 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
         quantized_model_reference = copy.deepcopy(model)
         quantize_(
             quantized_model_reference,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -124,7 +127,7 @@ def test_accuracy_aten(self):
         quantized_model = copy.deepcopy(model)
         quantize_(
             quantized_model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -136,7 +139,7 @@ def test_accuracy_aten(self):
         quantized_model_reference = copy.deepcopy(model)
         quantize_(
             quantized_model_reference,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -183,7 +186,7 @@ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
 
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -245,7 +248,7 @@ def test_export_dynamic_shape_PackedLinearInt8DynamicActivationIntxWeightLayout(
 
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -278,7 +281,7 @@ def test_export_QDQLayout(self):
 
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -304,6 +307,76 @@ def test_export_QDQLayout(self):
             exported.graph_module.code
         )
 
+    def test_replace_q_dq_with_torchao_quantized_linear_ops(self):
+        layers = [
+            torch.nn.Linear(256, 128, bias=True),
+            torch.nn.Linear(128, 64, bias=False),
+            torch.nn.Linear(64, 32, bias=True),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(2, 1, 256, dtype=torch.float32)
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(64),
+                has_weight_zeros=True,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: fqn == "0",
+        )
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int3,
+                granularity=PerRow(),
+                has_weight_zeros=False,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: fqn == "1",
+        )
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int5,
+                granularity=PerGroup(32),
+                has_weight_zeros=False,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: fqn == "2",
+        )
+
+        eager_results = model(activations)
+
+        unwrap_tensor_subclass(model)
+        exported = torch.export.export(model, (activations,), strict=True)
+        exported = replace_q_dq_with_torchao_quantized_linear_ops(exported)
+
+        # We should not find pack op because it gets constant folded
+        FileCheck().check_not("torch.ops.torchao._pack_8bit_act").run(
+            exported.graph_module.code
+        )
+
+        # We should find 3 torchao linear ops
+        FileCheck().check_count(
+            "torch.ops.torchao._linear_8bit_act_", count=3, exactly=True
+        ).run(exported.graph_module.code)
+
+        # We should not find Q/DQ ops
+        FileCheck().check_not("torch.ops.quant.quantize_affine.default").run(
+            exported.graph_module.code
+        )
+        FileCheck().check_not("torch.ops.quant.dequantize_affine.default").run(
+            exported.graph_module.code
+        )
+        FileCheck().check_not("torch.ops.quant.choose_qparams_affine.default").run(
+            exported.graph_module.code
+        )
+
+        # Numerics should match
+        exported_results = exported.module()(activations)
+        self.assertTrue(torch.allclose(exported_results, eager_results))
+
 
 if __name__ == "__main__":
     unittest.main()
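
For reference, a minimal sketch of the config-based workflow this test exercises, outside the test harness. The imports and API calls are the ones shown in the diff above; the toy model, shapes, and variable names are illustrative assumptions, not part of the change:

import torch
from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    replace_q_dq_with_torchao_quantized_linear_ops,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass

# Illustrative toy model; any float32 module containing nn.Linear layers
# is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(256, 128))
activations = torch.randn(1, 256, dtype=torch.float32)

# Quantize to dynamic int8 activations with 4-bit weights. QDQLayout keeps
# explicit quantize/dequantize ops in the graph so they can be pattern-matched
# after export.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(64),
        has_weight_zeros=True,
        layout=QDQLayout(),
    ),
)

unwrap_tensor_subclass(model)
exported = torch.export.export(model, (activations,), strict=True)
# Rewrite the exported Q/DQ + linear patterns into fused torchao quantized
# linear ops.
exported = replace_q_dq_with_torchao_quantized_linear_ops(exported)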