diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 27d6d8bb85..6667a2e907 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -2106,23 +2106,28 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
         from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
-        from torchao.prototype.qat import NVFP4FakeQuantizeConfig
+        from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+        from torchao.prototype.qat import (
+            NVFP4FakeQuantizeConfig,
+            NVFP4FakeQuantizedLinear,
+        )
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
-        quantize_(
-            baseline_model,
-            NVFP4DynamicActivationNVFP4WeightConfig(
-                use_dynamic_per_tensor_scale=use_per_tensor_scale
-            ),
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=use_per_tensor_scale
         )
+        quantize_(baseline_model, base_config)
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             step="prepare",
         )
         quantize_(m, qat_config)
+        self.assertEqual(type(m.linear1), NVFP4FakeQuantizedLinear)
+        self.assertEqual(type(m.linear2), NVFP4FakeQuantizedLinear)
+        self.assertEqual(type(m.sub.linear), NVFP4FakeQuantizedLinear)
 
         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -2132,6 +2137,18 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreaterEqual(sqnr, float("inf"))
 
+        # Compare converted values
+        quantize_(m, QATConfig(base_config, step="convert"))
+        self.assertEqual(type(m.linear1), torch.nn.Linear)
+        self.assertEqual(type(m.linear2), torch.nn.Linear)
+        self.assertEqual(type(m.sub.linear), torch.nn.Linear)
+        self.assertEqual(type(m.linear1.weight), NVFP4Tensor)
+        self.assertEqual(type(m.linear2.weight), NVFP4Tensor)
+        self.assertEqual(type(m.sub.linear.weight), NVFP4Tensor)
+        out = m(*x)
+        sqnr = compute_error(out, baseline_out).item()
+        self.assertGreaterEqual(sqnr, float("inf"))
+
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
diff --git a/torchao/prototype/qat/nvfp4.py b/torchao/prototype/qat/nvfp4.py
index 0a2316621f..9d635faed6 100644
--- a/torchao/prototype/qat/nvfp4.py
+++ b/torchao/prototype/qat/nvfp4.py
@@ -162,6 +162,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             return fq
 
+    def to_linear(self) -> torch.nn.Linear:
+        new_linear = torch.nn.Linear(
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            device=self.weight.device,
+            dtype=self.weight.dtype,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if self.weight.device != torch.device("meta"):
+            new_linear.weight = self.weight
+            new_linear.bias = self.bias
+        return new_linear
+
     @classmethod
     def from_linear(
         cls,
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 551a6d5da0..8720eb6676 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -198,6 +198,13 @@ def _qat_config_transform(
     modules to the corresponding built-in `torch.nn.Module`s, then apply
     the base config directly to quantize the module.
     """
+    # TODO: rewrite this using a registration API so
+    # specific quantization schemes do not leak here
+    from torchao.prototype.qat import (
+        NVFP4FakeQuantizeConfig,
+        NVFP4FakeQuantizedLinear,
+    )
+
     # Prepare step
     # Swap nn.Linear -> FakeQuantizedLinear
     # Swap nn.Embedding -> FakeQuantizedEmbedding
@@ -210,13 +217,6 @@ def _qat_config_transform(
         act_config = config.activation_config
         weight_config = config.weight_config
         if isinstance(module, torch.nn.Linear):
-            # TODO: rewrite this using a registration API so
-            # specific quantization schemes do not leak here
-            from torchao.prototype.qat import (
-                NVFP4FakeQuantizeConfig,
-                NVFP4FakeQuantizedLinear,
-            )
-
             if isinstance(weight_config, NVFP4FakeQuantizeConfig):
                 assert act_config is None or isinstance(
                     act_config, NVFP4FakeQuantizeConfig
@@ -245,17 +245,20 @@ def _qat_config_transform(
         assert config.weight_config is None, "unexpected `weight_config`"
 
         # Ignore unrelated modules
-        if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
+        if not isinstance(
+            module,
+            (FakeQuantizedLinear, FakeQuantizedEmbedding, NVFP4FakeQuantizedLinear),
+        ):
             return module
 
         # Optionally pass custom scales and zero points to base config handler
         # This is only for range learning and only applies to weights
         kwargs = {}
         has_custom_scale_and_zero_point = False
-        weight_config = module.weight_fake_quantizer.config
         if (
-            isinstance(weight_config, IntxFakeQuantizeConfig)
-            and weight_config.range_learning
+            hasattr(module, "weight_fake_quantizer")
+            and isinstance(module.weight_fake_quantizer.config, IntxFakeQuantizeConfig)
+            and module.weight_fake_quantizer.config.range_learning
         ):
             kwargs["custom_scale"] = module.weight_fake_quantizer.scale
             kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point
@@ -265,7 +268,7 @@ def _qat_config_transform(
         # Swap FakeQuantizedEmbedding -> nn.Embedding
         # Then apply the base config's transform function to quantize the model
         # If there is no base config, then simply perform the module swap
-        if isinstance(module, FakeQuantizedLinear):
+        if isinstance(module, (FakeQuantizedLinear, NVFP4FakeQuantizedLinear)):
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()