
Commit f99105a

Add int8 static quantization workflow (#3442)
* Int8Tensor migration

  Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow.

  Test Plan. To ensure BC:

  ```
  pytest test/quantization/test_quant_api.py
  ```

  To test the new Int8Tensor:

  ```
  pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
  ```

* ruff fixes
* add init
* fix ruff again
* update
* wip
* undo update tests
* fix ruff
* fix varname
* fix typing
* add tests
* fix dtype
* fix ci
* address granularity cr
* update _choose_quant_func_and_quantize_tensor
* make block size required attribute
* made dtype required as well
* address nits
* skip per tensor weight only test for now
* add static quant
* add static quant
* update
* static quant working eager + compile
* remove file
* added asserts
* undo smoothquant change
* fix return
* address cr feedback
1 parent 08e5e20 commit f99105a
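
Taken together, the changes below add a calibrate-then-freeze flow: observe activation scales with the dynamic int8 path, then bake them into a static config. A minimal sketch assembled from the new test in this commit (the model and calibration input are illustrative, not part of the commit):

```python
import copy

import torch
from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)

model = torch.nn.Linear(32, 32, bias=False).cuda().bfloat16().eval()
calibration_input = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)

# 1) quantize a copy dynamically and reuse its activation-quant kwargs
#    to observe the int8 scales for the calibration batch
calib_model = copy.deepcopy(model)
quantize_(
    calib_model,
    Int8DynamicActivationInt8WeightConfig(version=2, granularity=PerRow()),
)
int8_act = _choose_quant_func_and_quantize_tensor(
    calibration_input, calib_model.weight.act_quant_kwargs
)

# 2) freeze the observed scale into the new static config
quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(
        scale=int8_act.scale.detach().clone(), granularity=PerRow()
    ),
)
out = model(calibration_input)  # activations now quantized with the fixed scale
```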

File tree

5 files changed: +191 −23 lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 65 additions & 0 deletions

```diff
@@ -14,11 +14,15 @@

 from torchao.quantization import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.quantize_.common import (
+    _choose_quant_func_and_quantize_tensor,
+)
 from torchao.quantization.utils import compute_error, get_block_size
 from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import TorchAOIntegrationTestCase
@@ -221,5 +225,66 @@ def test_available_gpu_kernels(self):
         ).check_count("triton_poi_fused", 1).run(code[0])


+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@common_utils.instantiate_parametrized_tests
+class TestInt8StaticQuant(TorchAOIntegrationTestCase):
+    @common_utils.parametrize("granularity", [PerRow(), PerTensor()])
+    @common_utils.parametrize("dtype", [torch.bfloat16])
+    def test_static_activation_per_row_int8_weight(self, granularity, dtype):
+        torch.compiler.reset()
+
+        M, N, K = 32, 32, 32
+        input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")
+
+        model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
+        model_static_quant = copy.deepcopy(model)
+        model_dynamic_quant = copy.deepcopy(model)
+
+        model_out_baseline = model(input_tensor)
+
+        dynamic_config = Int8DynamicActivationInt8WeightConfig(
+            version=2, granularity=granularity
+        )
+        quantize_(model_dynamic_quant, dynamic_config)
+
+        dynamic_out_eager = model_dynamic_quant(input_tensor)
+        sqnr_dynamic_eager = compute_error(model_out_baseline, dynamic_out_eager)
+
+        model_dynamic_quant = torch.compile(model_dynamic_quant, fullgraph=True)
+
+        dynamic_out_compile = model_dynamic_quant(input_tensor)
+        sqnr_dynamic_compile = compute_error(model_out_baseline, dynamic_out_compile)
+
+        # use the activation scales observed in eager mode as the static scales
+        int8_input = _choose_quant_func_and_quantize_tensor(
+            input_tensor, model_dynamic_quant.weight.act_quant_kwargs
+        )
+
+        static_config = Int8StaticActivationInt8WeightConfig(
+            scale=int8_input.scale.detach().clone(),
+            granularity=granularity,
+        )
+        quantize_(model_static_quant, static_config)
+
+        static_out_eager = model_static_quant(input_tensor)
+        sqnr_static_eager = compute_error(model_out_baseline, static_out_eager)
+
+        model_static_quant = torch.compile(model_static_quant, fullgraph=True)
+
+        static_out_compile = model_static_quant(input_tensor)
+        sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)
+
+        assert (
+            sqnr_static_compile
+            == sqnr_static_eager
+            == sqnr_dynamic_compile
+            == sqnr_dynamic_eager
+        ), "SQNR should be the same for all quantization methods and eager/compile"
+
+        # eager numerics should match exactly
+        # for compile, we can't compare dynamic vs static because we may get
+        # slightly different qparams when fused
+        torch.testing.assert_close(dynamic_out_eager, static_out_eager)
+
+
 if __name__ == "__main__":
     common_utils.run_tests()
```
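
For intuition about the scales the test reuses: with symmetric int8 quantization, a scale is roughly the absolute maximum over each quantization block divided by the int8 maximum. A hand-rolled per-row sketch (illustrative only; the test goes through `choose_qparams_affine`, which handles edge cases this does not):

```python
import torch

x = torch.randn(32, 32, dtype=torch.bfloat16)

# PerRow(): one scale per row, computed from the row's absolute maximum
scale = x.abs().amax(dim=-1, keepdim=True).float() / 127

q = torch.clamp(torch.round(x.float() / scale), -128, 127).to(torch.int8)
x_hat = q.float() * scale  # dequantized approximation of x
```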

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -59,6 +59,7 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
@@ -150,6 +151,7 @@
     "Int8DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
+    "Int8StaticActivationInt8WeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
```

torchao/quantization/quant_api.py

Lines changed: 87 additions & 6 deletions

```diff
@@ -88,6 +88,7 @@
     IntxPackingFormat,
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
+    QuantizeTensorToInt8Kwargs,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
         )
         quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
     else:
-        from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
-            QuantizeTensorToInt8Kwargs,
-        )
-
         assert config.granularity in {PerRow(), PerTensor()}, (
             "Only PerRow and PerTensor are supported"
         )
@@ -1621,7 +1618,10 @@

 @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
 def _int8_dynamic_activation_int8_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationInt8WeightConfig,
+    *,
+    parameter_name="weight",
 ) -> torch.nn.Module:
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,7 +1634,88 @@ def _int8_dynamic_activation_int8_weight_transform(
         module.weight, config
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
+    return module
+
+
+@dataclass
+class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 static symmetric quantization to both activation and weight.
+
+    Args:
+        scale (torch.Tensor): the scale tensor for activation quantization
+        granularity (Granularity): the granularity of quantization; only PerRow() and PerTensor() are supported currently
+        act_mapping_type (MappingType): the mapping type for activation quantization; only SYMMETRIC is supported currently
+        set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values
+        version (int): the version of the config
+    """
+
+    scale: torch.Tensor
+    granularity: Granularity = PerRow()
+    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
+    set_inductor_config: bool = True
+    version: int = 1
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8StaticActivationInt8WeightConfig"
+        )
+
+
+@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
+def _int8_static_activation_int8_weight_transform(
+    module: torch.nn.Module,
+    config: Int8StaticActivationInt8WeightConfig,
+    *,
+    parameter_name="weight",
+):
+    assert config.granularity in {PerRow(), PerTensor()}, (
+        "Only PerRow and PerTensor are supported currently"
+    )
+    assert config.act_mapping_type == MappingType.SYMMETRIC, (
+        "asymmetric static quant not supported currently"
+    )
+    assert hasattr(module, parameter_name), (
+        f"Expected module to have attribute `{parameter_name}` but not found"
+    )
+
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    activation_granularity = config.granularity
+    weight_granularity = config.granularity
+
+    quantized_tensor = Int8Tensor.from_hp(
+        getattr(module, parameter_name),
+        granularity=weight_granularity,
+        act_quant_kwargs=QuantizeTensorToInt8Kwargs(
+            granularity=activation_granularity,
+            mapping_type=config.act_mapping_type,
+        ),
+        act_scale=config.scale.detach(),
+    )
+
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module
```
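
A minimal application of the new config might look like the sketch below; the single-element scale stands in for a calibrated PerTensor() activation scale and is illustrative only:

```python
import torch
from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerTensor

linear = torch.nn.Linear(64, 64, bias=False).cuda().bfloat16()

# shape (1, 1): one scale entry per quantization block of a 2D activation
act_scale = torch.tensor([[0.02]], device="cuda", dtype=torch.float32)

quantize_(
    linear,
    Int8StaticActivationInt8WeightConfig(scale=act_scale, granularity=PerTensor()),
)
print(linear.weight)  # Int8Tensor(..., act_scale=..., ...) per the new __repr__
```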

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import abc
-from typing import ClassVar
+from typing import ClassVar, Optional

 import torch

@@ -31,7 +31,9 @@ def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs)


 def _choose_quant_func_and_quantize_tensor(
-    tensor: torch.Tensor, quant_kwargs: QuantizeTensorKwargs
+    tensor: torch.Tensor,
+    quant_kwargs: QuantizeTensorKwargs,
+    scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs
     quantizes tensor to the derived dtype chosen in (1)
@@ -60,6 +62,7 @@ def _choose_quant_func_and_quantize_tensor(
             tensor,
             quant_kwargs.granularity,
             mapping_type=quant_kwargs.mapping_type,
+            scale=scale,
         )

     raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
```
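
A sketch of the new optional `scale` argument: for the int8 branch it is forwarded to `Int8Tensor.from_hp`, so qparams are reused instead of recomputed. The `QuantizeTensorToInt8Kwargs` import path is taken from the local import this commit removes from quant_api.py, and the default `mapping_type` is assumed; treat the whole snippet as illustrative:

```python
import torch
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

x = torch.randn(8, 16)
kwargs = QuantizeTensorToInt8Kwargs(granularity=PerTensor())

dynamic = _choose_quant_func_and_quantize_tensor(x, kwargs)  # qparams chosen from x
static = _choose_quant_func_and_quantize_tensor(x, kwargs, scale=dynamic.scale)

torch.testing.assert_close(dynamic.qdata, static.qdata)  # same scale -> same qdata
```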

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 32 additions & 15 deletions

```diff
@@ -53,15 +53,14 @@ class Int8Tensor(TorchAOBaseTensor):
     Tensor Attributes:
         qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D)
         scale: scale factors for dequantization
-        # TODO: Static quantization support using `static_scale`

     Non-Tensor Attributes:
         granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
         act_quant_kwargs: flags for dynamic activation quantization
     """

-    # TODO: Static quantization support using `static_scale`
     tensor_data_names = ["qdata", "scale"]
+    optional_tensor_data_names = ["act_scale"]
     tensor_attribute_names = ["block_size", "dtype"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
@@ -73,6 +72,7 @@ def __new__(
         scale: torch.Tensor,
         block_size: List[int],
         dtype: torch.dtype,
+        act_scale=None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         kwargs = {
@@ -88,6 +88,7 @@ def __init__(
         scale: torch.Tensor,
         block_size: List[int],
         dtype: torch.dtype,
+        act_scale=None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         super().__init__()
@@ -96,13 +97,15 @@ def __init__(
         self.block_size = block_size
         # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
+        self.act_scale = act_scale

     def __repr__(self):
         return (
             f"{self.__class__.__name__}("
             f"act_quant_kwargs={self.act_quant_kwargs}, "
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
+            f"act_scale={self.act_scale}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -114,24 +117,35 @@ def from_hp(
         cls,
         hp_tensor: torch.Tensor,
         granularity: Granularity,
-        act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         mapping_type=MappingType.SYMMETRIC,
+        scale: Optional[torch.Tensor] = None,
+        act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
+        act_scale: Optional[torch.Tensor] = None,
     ):
         """Create Int8Tensor from high-precision tensor"""
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)

-        scale, zero_point = choose_qparams_affine(
-            input=hp_tensor,
-            mapping_type=mapping_type,
-            block_size=block_size,
-            target_dtype=torch.int8,
-            quant_min=-128,
-            quant_max=127,
-            scale_dtype=hp_tensor.dtype,
-            zero_point_dtype=torch.int8,
-            keepdim=True,
-        )
+        if scale is None:
+            scale, zero_point = choose_qparams_affine(
+                input=hp_tensor,
+                mapping_type=mapping_type,
+                block_size=block_size,
+                target_dtype=torch.int8,
+                quant_min=-128,
+                quant_max=127,
+                scale_dtype=hp_tensor.dtype,
+                zero_point_dtype=torch.int8,
+                keepdim=True,
+            )
+        else:
+            # scale can be provided in the case of static quant
+            assert scale.ndim == hp_tensor.ndim
+            assert all(
+                (hp_tensor.shape[i] // block_size[i]) == scale.shape[i]
+                for i in range(hp_tensor.ndim)
+            )
+            zero_point = torch.zeros_like(scale, dtype=torch.int8)

         int_data = quantize_affine(
             hp_tensor,
@@ -146,6 +160,7 @@ def from_hp(
             scale,
             block_size,
             hp_tensor.dtype,
+            act_scale=act_scale,
             act_quant_kwargs=act_quant_kwargs,
         )

@@ -185,7 +200,9 @@ def _(func, types, args, kwargs):

     if weight_tensor.act_quant_kwargs is not None:
         activation_tensor = _choose_quant_func_and_quantize_tensor(
-            activation_tensor, weight_tensor.act_quant_kwargs
+            activation_tensor,
+            weight_tensor.act_quant_kwargs,
+            scale=weight_tensor.act_scale,
         )
     # Dynamic activation quantization path
```
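
The shape contract for a caller-provided scale can be read off the asserts above: one entry per quantization block, in keepdim layout, with zero_point implicitly zero (symmetric). A small sketch for PerRow() on a 2D weight; the class and argument names come from this diff, while the scale values are illustrative:

```python
import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.workflows.int8.int8_tensor import Int8Tensor

w = torch.randn(4, 8, dtype=torch.bfloat16)

# block_size for PerRow() on a (4, 8) tensor is [1, 8], so scale must be (4, 1)
scale = w.abs().amax(dim=-1, keepdim=True) / 127

t = Int8Tensor.from_hp(w, granularity=PerRow(), scale=scale)
assert t.qdata.dtype == torch.int8
assert t.scale.shape == (4, 1)
```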
