37 commits
48cdb61
Int8Tensor migration
jcaip Dec 1, 2025
0b73aed
ruff fixes
jcaip Dec 1, 2025
1e49945
add init
jcaip Dec 1, 2025
669b6ee
fix ruff again
jcaip Dec 1, 2025
9071526
update
jcaip Dec 1, 2025
1539e0f
wip
jcaip Dec 2, 2025
d9a2b1b
Merge branch 'main' into jcaip/int8-tensor
jcaip Dec 3, 2025
673f228
undo update tests
jcaip Dec 3, 2025
739fd64
fix ruff
jcaip Dec 3, 2025
750db1a
fix varname
jcaip Dec 3, 2025
9410488
fix typing
jcaip Dec 3, 2025
45a3a76
add tests
jcaip Dec 3, 2025
4e2f09c
fix dtype
jcaip Dec 3, 2025
dd80cca
fix ci
jcaip Dec 3, 2025
7f73062
address granularity cr
jcaip Dec 4, 2025
ac6a2b6
update _choose_quant_func_and_quantize_tensor
jcaip Dec 4, 2025
f28df4a
make block size required attribute
jcaip Dec 4, 2025
328585e
made dtype required as well
jcaip Dec 4, 2025
ce4d568
address nits
jcaip Dec 4, 2025
a665d45
skip per tensor weight only test for now
jcaip Dec 4, 2025
0338016
add static quant
jcaip Dec 3, 2025
ee39691
add static quant
jcaip Dec 4, 2025
9eb0aa9
update
jcaip Dec 5, 2025
d4a1514
static quant working eager + compile
jcaip Dec 6, 2025
3cdea56
remove file
jcaip Dec 6, 2025
fa9022d
added asserts
jcaip Dec 6, 2025
8ce5cde
undo smoothquant change
jcaip Dec 6, 2025
6f64121
fix return
jcaip Dec 6, 2025
8ae921d
Merge branch 'main' into jcaip/static-quant-rebased
jcaip Dec 7, 2025
5b9e243
got smoothquant + int8 static working
jcaip Dec 8, 2025
7a0e38f
generalized smoothquat code
jcaip Dec 8, 2025
3d18edf
free tests
jcaip Dec 8, 2025
9e07f8b
fix static scale check
jcaip Dec 8, 2025
4274e02
update
jcaip Dec 8, 2025
b5309eb
address cr feedback
jcaip Dec 9, 2025
a732fee
Merge branch 'jcaip/static-quant-rebased' into jcaip/enable-smoothquant
jcaip Dec 9, 2025
0c23589
Merge branch 'main' into jcaip/enable-smoothquant
jcaip Dec 9, 2025
17 changes: 15 additions & 2 deletions test/prototype/test_smoothquant.py
@@ -15,11 +15,13 @@
)
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.linear_activation_scale import (
WeightTensorWithLinearActivationScaleMetadata,
)
from torchao.quantization.quant_api import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
)
from torchao.quantization.utils import (
compute_error as SQNR,
@@ -83,7 +85,10 @@ def setUpClass(cls):
@common_utils.parametrize(
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(version=2),
# TODO: not sure if we should allow not passing scales as part of static config?
Contributor:
yeah I think it's fine

side note: we may need a separate API/flow for plain static quant without Smoothquant if needed.
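For what that separate flow could look like, here is a minimal sketch of plain int8 static quantization without SmoothQuant, mirroring the new tests in `test_int8_tensor.py`: a dynamically quantized copy is used only to produce calibration scales, which are then passed to `Int8StaticActivationInt8WeightConfig`. The toy model and calibration tensor are made up for illustration.

```python
import copy

import torch

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)

# toy model and calibration batch (illustrative only)
model = torch.nn.Linear(32, 32, bias=False).cuda().to(torch.bfloat16)
calib_input = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16)

# quantize a throwaway copy dynamically just to get act_quant_kwargs,
# then reuse the resulting activation scale for the static config
probe = copy.deepcopy(model)
quantize_(probe, Int8DynamicActivationInt8WeightConfig(version=2, granularity=PerRow()))
int8_calib = _choose_quant_func_and_quantize_tensor(
    calib_input, probe.weight.act_quant_kwargs
)

quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(
        scale=int8_calib.scale.detach().clone(), granularity=PerRow()
    ),
)
out = model(calib_input)  # runs with statically quantized int8 activations
```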

Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
Int8StaticActivationInt8WeightConfig(granularity=PerTensor()),
# Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
# TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
],
@@ -101,7 +106,15 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):

# Step 1. Basic quantization
basic_model = deepcopy(m)
quantize_(basic_model, base_config)
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
quantize_(
basic_model,
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=base_config.granularity
),
)
else:
quantize_(basic_model, base_config)
out_basic = basic_model(*x)
loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

92 changes: 92 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -14,11 +14,15 @@

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase
@@ -221,5 +225,93 @@ def test_available_gpu_kernels(self):
).check_count("triton_poi_fused", 1).run(code[0])


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8StaticQuant(TorchAOIntegrationTestCase):
@common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_static_activation_per_row_int8_weight_eager(self, granularity, dtype):
M, N, K = 32, 32, 32
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

model_static_quant = (
torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
)
model_dynamic_quant = copy.deepcopy(model_static_quant)

dynamic_config = Int8DynamicActivationInt8WeightConfig(
version=2, granularity=granularity
)
quantize_(model_dynamic_quant, dynamic_config)

dynamic_quantize_out = model_dynamic_quant(input_tensor)

int8_input = _choose_quant_func_and_quantize_tensor(
input_tensor, model_dynamic_quant.weight.act_quant_kwargs
)

static_config = Int8StaticActivationInt8WeightConfig(
scale=int8_input.scale.detach().clone(), granularity=granularity
)
quantize_(model_static_quant, static_config)

static_quantize_out = model_static_quant(input_tensor)
torch.testing.assert_close(dynamic_quantize_out, static_quantize_out)

@common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_static_activation_per_row_int8_weight_compile(self, granularity, dtype):
# for compile, we can't compare dynamic vs static because we may get slightly different qparams
torch.compiler.reset()

M, N, K = 32, 32, 32
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
model_static_quant = copy.deepcopy(model)
model_dynamic_quant = copy.deepcopy(model)

model_out_baseline = model(input_tensor)

dynamic_config = Int8DynamicActivationInt8WeightConfig(
version=2, granularity=granularity
)
quantize_(model_dynamic_quant, dynamic_config)

dynamic_out_eager = model_dynamic_quant(input_tensor)
sqnr_dynamic_eager = compute_error(model_out_baseline, dynamic_out_eager)

model_dynamic_quant = torch.compile(model_dynamic_quant, fullgraph=True)

dynamic_out_compile = model_dynamic_quant(input_tensor)
sqnr_dynamic_compile = compute_error(model_out_baseline, dynamic_out_compile)

# use the eager-mode activation scales to construct the static config
int8_input = _choose_quant_func_and_quantize_tensor(
input_tensor, model_dynamic_quant.weight.act_quant_kwargs
)

static_config = Int8StaticActivationInt8WeightConfig(
scale=int8_input.scale.detach().clone(),
granularity=granularity,
)
quantize_(model_static_quant, static_config)

static_out_eager = model_static_quant(input_tensor)
sqnr_static_eager = compute_error(model_out_baseline, static_out_eager)

model_static_quant = torch.compile(model_static_quant, fullgraph=True)

static_out_compile = model_static_quant(input_tensor)
sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)

assert (
sqnr_static_compile
== sqnr_static_eager
== sqnr_dynamic_compile
== sqnr_dynamic_eager
)


if __name__ == "__main__":
common_utils.run_tests()
20 changes: 19 additions & 1 deletion torchao/prototype/smoothquant/api.py
@@ -15,8 +15,12 @@
)
from torchao.quantization.quant_api import (
_QUANTIZE_CONFIG_HANDLER,
Int8StaticActivationInt8WeightConfig,
_linear_extra_repr,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
@@ -95,8 +99,18 @@ def _smooth_quant_transform(
else:
raise ValueError(f"Unexpected step: {step}")

if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
quant_kwargs = QuantizeTensorToInt8Kwargs(
granularity=base_config.granularity,
mapping_type=base_config.act_mapping_type,
)
else:
quant_kwargs = None

# Compute smoothed weight parameters
smoothing_factor = observed_linear.obs.calculate_qparams()
smoothing_factor, activation_scale = observed_linear.obs.calculate_qparams(
weight_quant_kwargs=quant_kwargs
)
weight = observed_linear.weight * smoothing_factor

# Create new linear layer
@@ -111,6 +125,9 @@ def _smooth_quant_transform(
linear.bias = observed_linear.bias

# Quantize weights
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
base_config.scale = activation_scale

base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
dummy_mod = DummyModule(weight)
quant_mod = base_config_handler(dummy_mod, base_config)
@@ -120,6 +137,7 @@ def _smooth_quant_transform(
qw = to_weight_tensor_with_linear_activation_scale_metadata(
Contributor:
we should not be using this, please check awq on how this should be implemented in the new stack:

assert isinstance(qw, SupportsActivationPreScaling), (
"weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale
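A rough sketch of that alternative applied to this transform, assuming `Int8Tensor` gains `SupportsActivationPreScaling` support and that the protocol is importable from `torchao.quantization.quantize_.common` (neither is shown in this diff):

```python
# hypothetical replacement for the to_weight_tensor_with_linear_activation_scale_metadata
# wrapping below, following the AWQ-style pattern the reviewer points to
from torchao.quantization.quantize_.common import SupportsActivationPreScaling

qw = quant_mod.weight
assert isinstance(qw, SupportsActivationPreScaling), (
    "weight must support activation scaling through `SupportsActivationPreScaling`"
)
# at runtime the op computes `act * act_pre_scale`, so store the reciprocal
# of the smoothing (equalization) factor on the quantized weight itself
qw.act_pre_scale = 1.0 / smoothing_factor.to(qw.dtype)
linear.weight = torch.nn.Parameter(qw, requires_grad=False)
```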

qw, smoothing_factor.to(qw.dtype)
)

linear.weight = torch.nn.Parameter(qw, requires_grad=False)
linear.extra_repr = types.MethodType(_linear_extra_repr, linear)

25 changes: 19 additions & 6 deletions torchao/prototype/smoothquant/core.py
@@ -9,6 +9,10 @@
import torch
import torch.nn.functional as F

from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)


class SmoothQuantStep(str, Enum):
PREPARE = "prepare"
@@ -41,13 +45,14 @@ def forward(self, input: torch.Tensor):
self.inputs.append(input.to("cpu"))
return input

def calculate_qparams(self):
def calculate_qparams(self, weight_quant_kwargs=None):
assert self.inputs and len(self.inputs) > 0, (
"calibrate observer first by running model on exemplar data"
)
inputs = [inp.to(self.device) for inp in self.inputs]
acc = torch.cat(inputs, dim=0)
# Reshape if needed: [batch, seq, features] -> [batch*seq, features]
example_input_for_quantization = acc
if acc.ndim > 2:
acc = acc.view(-1, acc.shape[-1])

@@ -57,12 +62,20 @@ def calculate_qparams(self):

# Calculate smoothing factor
if self.alpha is None:
return torch.ones_like(x_abs_max)
smoothing_factor = torch.ones_like(x_abs_max)
else:
eps = torch.finfo(torch.float32).eps
smoothing_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
w_abs_max + eps, 1 - self.alpha
)

eps = torch.finfo(torch.float32).eps
return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
w_abs_max + eps, 1 - self.alpha
)
if weight_quant_kwargs is not None:
quant_smooth_activation = _choose_quant_func_and_quantize_tensor(
example_input_for_quantization / smoothing_factor, weight_quant_kwargs
)
return smoothing_factor, quant_smooth_activation.scale
else:
return smoothing_factor, None
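For context, the smoothing math is unchanged: with per-input-channel maxima max|X_j| and max|W_j|, the factor is s_j = max|X_j|**alpha / max|W_j|**(1 - alpha); the new path then quantizes X / s with the weight's int8 kwargs purely to capture the resulting activation scale. A minimal numeric sketch with toy tensors (the per-row symmetric int8 scale formula stands in for `_choose_quant_func_and_quantize_tensor` and is an assumption):

```python
import torch

alpha = 0.5
X = torch.randn(64, 16)   # calibration activations, [tokens, in_features]
W = torch.randn(32, 16)   # linear weight, [out_features, in_features]

x_abs_max = X.abs().amax(dim=0)  # per-input-channel activation max
w_abs_max = W.abs().amax(dim=0)  # per-input-channel weight max

eps = torch.finfo(torch.float32).eps
smoothing_factor = (x_abs_max + eps).pow(alpha) / (w_abs_max + eps).pow(1 - alpha)

# the smoothed activation X / s is what gets statically quantized; its
# scale is what calculate_qparams now returns alongside the smoothing factor
smoothed = X / smoothing_factor
activation_scale = smoothed.abs().amax(dim=-1, keepdim=True) / 127.0  # per-row symmetric int8
```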


class SmoothQuantObservedLinear(torch.nn.Linear):
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
ModuleFqnToConfig,
@@ -150,6 +151,7 @@
"Int8DynamicActivationInt4WeightConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int8StaticActivationInt8WeightConfig",
"Int4WeightOnlyConfig",
"Float8DynamicActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
93 changes: 87 additions & 6 deletions torchao/quantization/quant_api.py
@@ -88,6 +88,7 @@
IntxPackingFormat,
IntxUnpackedToInt8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
)
quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
else:
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
QuantizeTensorToInt8Kwargs,
)

assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor are supported"
)
@@ -1621,7 +1618,10 @@

@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
def _int8_dynamic_activation_int8_weight_transform(
module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
module: torch.nn.Module,
config: Int8DynamicActivationInt8WeightConfig,
*,
parameter_name="weight",
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,7 +1634,88 @@ def _int8_dynamic_activation_int8_weight_transform(
module.weight, config
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


@dataclass
class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
"""
Configuration for applying int8 static symmetric quantization to activations and int8 symmetric quantization to weights of linear layers.

Args:
scale (Optional[torch.Tensor]): Pre-computed activation scale tensor used for static activation quantization.
granularity (Granularity): Quantization granularity for activations and weights; PerRow() or PerTensor(). Default is PerRow().
act_mapping_type (Optional[MappingType]): Mapping type for activation quantization; only MappingType.SYMMETRIC is currently supported.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): Version of the config. Default is 1.
"""

scale: torch.Tensor = None
Contributor:
nit: Optional[torch.Tensor]

granularity: Granularity = PerRow()
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
set_inductor_config: bool = True
version: int = 1

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8StaticActivationInt8WeightConfig"
)


@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
def _int8_static_activation_int8_weight_transform(
module: torch.nn.Module,
config: Int8StaticActivationInt8WeightConfig,
*,
parameter_name="weight",
):
assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor is supported currently"
)
assert config.act_mapping_type == MappingType.SYMMETRIC, (
"asymmetric static quant not supported currently"
)
assert hasattr(module, parameter_name), (
f"Expected module to have attribute `{parameter_name}` but not found"
)

if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

activation_granularity = config.granularity
weight_granularity = config.granularity

quantized_tensor = Int8Tensor.from_hp(
getattr(module, parameter_name),
granularity=weight_granularity,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
),
activation_scale=config.scale.detach(),
)

setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_tensor, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


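Putting the pieces together, the intended SmoothQuant plus int8 static flow is roughly the sketch below: prepare inserts observers, calibration data populates them, and convert computes the smoothing factors and static activation scales before quantizing. The `SmoothQuantConfig` import path and constructor signature are assumed from the existing prototype API rather than shown in this diff.

```python
import torch

from torchao.prototype.smoothquant import SmoothQuantConfig  # assumed import path
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False)).cuda().to(torch.bfloat16)
base_config = Int8StaticActivationInt8WeightConfig(granularity=PerRow())

# 1. prepare: swap Linear for observed Linear modules
quantize_(model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.PREPARE))

# 2. calibrate: run representative data through the observers
for _ in range(8):
    model(torch.randn(4, 64, device="cuda", dtype=torch.bfloat16))

# 3. convert: compute smoothing factors + static activation scales, then quantize
quantize_(model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.CONVERT))
```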