Merged
29 changes: 23 additions & 6 deletions test/quantization/test_qat.py
@@ -2106,23 +2106,28 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
Test QAT with `NVFP4FakeQuantizeConfig`.
"""
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
- from torchao.prototype.qat import NVFP4FakeQuantizeConfig
+ from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+ from torchao.prototype.qat import (
+     NVFP4FakeQuantizeConfig,
+     NVFP4FakeQuantizedLinear,
+ )

torch.manual_seed(self.SEED)
m = M().cuda()
baseline_model = copy.deepcopy(m)
- quantize_(
-     baseline_model,
-     NVFP4DynamicActivationNVFP4WeightConfig(
-         use_dynamic_per_tensor_scale=use_per_tensor_scale
-     ),
+ base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+     use_dynamic_per_tensor_scale=use_per_tensor_scale
)
+ quantize_(baseline_model, base_config)
qat_config = QATConfig(
activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
step="prepare",
)
quantize_(m, qat_config)
self.assertEqual(type(m.linear1), NVFP4FakeQuantizedLinear)
self.assertEqual(type(m.linear2), NVFP4FakeQuantizedLinear)
self.assertEqual(type(m.sub.linear), NVFP4FakeQuantizedLinear)

# Compare prepared values
torch.manual_seed(self.SEED)
@@ -2132,6 +2137,18 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
sqnr = compute_error(out, baseline_out).item()
self.assertGreaterEqual(sqnr, float("inf"))

# Compare converted values
Contributor:
oh we didn't compare convert results before?

Contributor (Author):
It's tested here:

def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):

We just never explicitly checked it's using tensor subclasses after convert (tests still passed because QAT prepare mimics PTQ exactly)

quantize_(m, QATConfig(base_config, step="convert"))
self.assertEqual(type(m.linear1), torch.nn.Linear)
self.assertEqual(type(m.linear2), torch.nn.Linear)
self.assertEqual(type(m.sub.linear), torch.nn.Linear)
self.assertEqual(type(m.linear1.weight), NVFP4Tensor)
self.assertEqual(type(m.linear2.weight), NVFP4Tensor)
self.assertEqual(type(m.sub.linear.weight), NVFP4Tensor)
out = m(*x)
sqnr = compute_error(out, baseline_out).item()
self.assertGreaterEqual(sqnr, float("inf"))

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
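Note on the `assertGreaterEqual(sqnr, float("inf"))` assertions: they only pass when the QAT and baseline outputs are bitwise identical, since SQNR is infinite exactly when the error term is zero. A minimal sketch of that property, assuming `compute_error` is the SQNR helper from `torchao.quantization.utils` (the test file's imports sit outside this hunk):

import torch

from torchao.quantization.utils import compute_error  # assumed import path

x = torch.randn(16, 32)
# Identical tensors -> zero error -> infinite SQNR, so the assertion holds
assert compute_error(x, x.clone()).item() == float("inf")
# Any mismatch -> finite SQNR, so the assertion would fail
assert compute_error(x, x + 1e-3).item() < float("inf")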
16 changes: 16 additions & 0 deletions torchao/prototype/qat/nvfp4.py
@@ -162,6 +162,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
else:
return fq

def to_linear(self) -> torch.nn.Linear:
new_linear = torch.nn.Linear(
self.in_features,
self.out_features,
self.bias is not None,
device=self.weight.device,
dtype=self.weight.dtype,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if self.weight.device != torch.device("meta"):
new_linear.weight = self.weight
new_linear.bias = self.bias
return new_linear

@classmethod
def from_linear(
cls,
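The new `to_linear` parallels the existing `FakeQuantizedLinear.to_linear` and is what the convert step calls to swap the fake-quantized module back to a plain `nn.Linear` (see `api.py` below). A rough usage sketch, not part of this PR; it assumes `NVFP4FakeQuantizedLinear.from_linear(mod, activation_config, weight_config)` mirrors the base `FakeQuantizedLinear.from_linear` signature and that `NVFP4FakeQuantizeConfig` accepts default arguments:

import torch

from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear

cfg = NVFP4FakeQuantizeConfig()  # assumption: defaults are accepted

# Regular tensors: to_linear() hands the existing weight/bias Parameters to the new nn.Linear
fq = NVFP4FakeQuantizedLinear.from_linear(torch.nn.Linear(64, 128), cfg, cfg)
assert fq.to_linear().weight is fq.weight

# Meta device (e.g. deferred init in distributed training): per the comment in the
# diff, the copy is skipped, so converting a meta-initialized module does not error
fq_meta = NVFP4FakeQuantizedLinear.from_linear(
    torch.nn.Linear(64, 128, device="meta"), cfg, cfg
)
assert fq_meta.to_linear().weight.device.type == "meta"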
27 changes: 15 additions & 12 deletions torchao/quantization/qat/api.py
@@ -198,6 +198,13 @@ def _qat_config_transform(
modules to the corresponding built-in `torch.nn.Module`s, then apply the
base config directly to quantize the module.
"""
# TODO: rewrite this using a registration API so
# specific quantization schemes do not leak here
from torchao.prototype.qat import (
NVFP4FakeQuantizeConfig,
NVFP4FakeQuantizedLinear,
)

# Prepare step
# Swap nn.Linear -> FakeQuantizedLinear
# Swap nn.Embedding -> FakeQuantizedEmbedding
@@ -210,13 +217,6 @@ def _qat_config_transform(
act_config = config.activation_config
weight_config = config.weight_config
if isinstance(module, torch.nn.Linear):
- # TODO: rewrite this using a registration API so
- # specific quantization schemes do not leak here
- from torchao.prototype.qat import (
-     NVFP4FakeQuantizeConfig,
-     NVFP4FakeQuantizedLinear,
- )

if isinstance(weight_config, NVFP4FakeQuantizeConfig):
assert act_config is None or isinstance(
act_config, NVFP4FakeQuantizeConfig
@@ -245,17 +245,20 @@ def _qat_config_transform(
assert config.weight_config is None, "unexpected `weight_config`"

# Ignore unrelated modules
- if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
+ if not isinstance(
+     module,
+     (FakeQuantizedLinear, FakeQuantizedEmbedding, NVFP4FakeQuantizedLinear),
+ ):
return module

# Optionally pass custom scales and zero points to base config handler
# This is only for range learning and only applies to weights
kwargs = {}
has_custom_scale_and_zero_point = False
- weight_config = module.weight_fake_quantizer.config
if (
-     isinstance(weight_config, IntxFakeQuantizeConfig)
-     and weight_config.range_learning
+     hasattr(module, "weight_fake_quantizer")
+     and isinstance(module.weight_fake_quantizer.config, IntxFakeQuantizeConfig)
+     and module.weight_fake_quantizer.config.range_learning
):
kwargs["custom_scale"] = module.weight_fake_quantizer.scale
kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point
@@ -265,7 +268,7 @@ def _qat_config_transform(
# Swap FakeQuantizedEmbedding -> nn.Embedding
# Then apply the base config's transform function to quantize the model
# If there is no base config, then simply perform the module swap
- if isinstance(module, FakeQuantizedLinear):
+ if isinstance(module, (FakeQuantizedLinear, NVFP4FakeQuantizedLinear)):
module = module.to_linear()
elif isinstance(module, FakeQuantizedEmbedding):
module = module.to_embedding()
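Taken together, the three files above let the NVFP4 QAT flow round-trip through prepare and convert. A minimal end-to-end sketch of that flow, not from the PR itself; it assumes a CUDA device with NVFP4 support and a toy linear layer whose `in_features` is a multiple of the NVFP4 block size (16):

import torch

from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 128)).cuda()

# Prepare: nn.Linear -> NVFP4FakeQuantizedLinear (fake-quantized, still trainable)
fq = NVFP4FakeQuantizeConfig(True)  # True = per-tensor scale, passed positionally as in the test
quantize_(model, QATConfig(activation_config=fq, weight_config=fq, step="prepare"))
assert isinstance(model[0], NVFP4FakeQuantizedLinear)

# ... fine-tune the model here ...

# Convert: NVFP4FakeQuantizedLinear -> nn.Linear via to_linear(), after which the
# base config quantizes the weight into an NVFP4Tensor subclass
base = NVFP4DynamicActivationNVFP4WeightConfig(use_dynamic_per_tensor_scale=True)
quantize_(model, QATConfig(base, step="convert"))
assert type(model[0]) is torch.nn.Linear
assert isinstance(model[0].weight, NVFP4Tensor)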