Commit 3c1b335
Remove config functions like int4_weight_only (#3145)
**Summary:** As a follow-up to #2994, this commit removes all quantization functions that were used as configs, such as `int4_weight_only`. These functions were deprecated in 0.14.0 with removal scheduled for the next release; this commit performs that removal for 0.15.0.

**Test Plan:** CI
1 parent 69ce0fd commit 3c1b335
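
For anyone migrating, here is a minimal before/after sketch (an illustration, not code from this commit; `Int8WeightOnlyConfig` stands in for any of the replacement config classes):

```python
# Hedged migration sketch: the function-style configs removed by this commit
# versus the config classes that replace them. Int8WeightOnlyConfig is one
# representative replacement; the other configs follow the same pattern.
import torch.nn as nn

from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = nn.Sequential(nn.Linear(32, 64))

# Before (deprecated in 0.14.0, removed by this commit):
#     quantize_(model, int8_weight_only())
# After:
quantize_(model, Int8WeightOnlyConfig())
```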

File tree: 5 files changed, +21 −151 lines

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -270,7 +270,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
```

test/quantization/test_quant_api.py

Lines changed: 18 additions & 31 deletions
```diff
@@ -801,38 +801,28 @@ def test_int4wo_cuda_serialization(self):
 
     def test_config_deprecation(self):
         """
-        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
+        Test that old config functions like `Int8DynamicActivationInt4WeightConfig` trigger deprecation warnings.
         """
         from torchao.quantization import (
-            float8_dynamic_activation_float8_weight,
-            float8_static_activation_float8_weight,
-            float8_weight_only,
-            fpx_weight_only,
-            gemlite_uintx_weight_only,
-            int4_dynamic_activation_int4_weight,
-            int4_weight_only,
-            int8_dynamic_activation_int4_weight,
-            int8_dynamic_activation_int8_weight,
-            int8_weight_only,
-            uintx_weight_only,
+            Float8StaticActivationFloat8WeightConfig,
+            FPXWeightOnlyConfig,
+            GemliteUIntXWeightOnlyConfig,
+            Int4DynamicActivationInt4WeightConfig,
+            Int8DynamicActivationInt4WeightConfig,
+            UIntXWeightOnlyConfig,
         )
 
         # Reset deprecation warning state, otherwise we won't log warnings here
         warnings.resetwarnings()
 
         # Map from deprecated API to the args needed to instantiate it
         deprecated_apis_to_args = {
-            float8_dynamic_activation_float8_weight: (),
-            float8_static_activation_float8_weight: (torch.randn(3)),
-            float8_weight_only: (),
-            fpx_weight_only: (3, 2),
-            gemlite_uintx_weight_only: (),
-            int4_dynamic_activation_int4_weight: (),
-            int4_weight_only: (),
-            int8_dynamic_activation_int4_weight: (),
-            int8_dynamic_activation_int8_weight: (),
-            int8_weight_only: (),
-            uintx_weight_only: (torch.uint4,),
+            Float8StaticActivationFloat8WeightConfig: (torch.randn(3),),
+            FPXWeightOnlyConfig: (3, 2),
+            GemliteUIntXWeightOnlyConfig: (),
+            Int4DynamicActivationInt4WeightConfig: (),
+            Int8DynamicActivationInt4WeightConfig: (),
+            UIntXWeightOnlyConfig: (torch.uint4,),
         }
 
         # Call each deprecated API twice
@@ -841,19 +831,16 @@ def test_config_deprecation(self):
                 cls(*args)
                 cls(*args)
 
-            # Each call should have at least one warning.
-            # Some of them can have two warnings - one for deprecation,
-            # one for moving to prototype
-            # 1 warning - just deprecation
-            # 2 warnings - deprecation and prototype warnings
-            self.assertTrue(len(_warnings) in (1, 2))
+            self.assertTrue(len(_warnings) == 1)
             found_deprecated = False
             for w in _warnings:
-                if "is deprecated and will be removed in a future release" in str(
+                if "will be moving to prototype in a future release" in str(
                     w.message
                 ):
                     found_deprecated = True
-            self.assertTrue(found_deprecated)
+            self.assertTrue(
+                found_deprecated, f"did not find deprecated warning for {cls}"
+            )
 
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
```
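
The `_warnings` list in the hunk above comes from a recording context wrapped around the two constructor calls. A minimal sketch of that pattern (a reconstruction, not the test's exact code):

```python
# record=True collects emitted warnings into a list instead of printing
# them; the active warning filters (reset earlier via
# warnings.resetwarnings()) decide whether the repeated call is recorded.
import warnings

def capture_warnings(cls, args):
    """Instantiate `cls` twice and return the warnings recorded."""
    with warnings.catch_warnings(record=True) as _warnings:
        cls(*args)
        cls(*args)
    return _warnings
```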

torchao/quantization/__init__.py

Lines changed: 0 additions & 24 deletions
```diff
@@ -65,22 +65,10 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
-    float8_weight_only,
-    fpx_weight_only,
     fqn_matches_fqn_config,
-    gemlite_uintx_weight_only,
-    int4_dynamic_activation_int4_weight,
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
     intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
-    uintx_weight_only,
 )
 from .quant_primitives import (
     MappingType,
@@ -131,20 +119,8 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "int4_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
-    "uintx_weight_only",
-    "fpx_weight_only",
     "fqn_matches_fqn_config",
-    "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
```

torchao/quantization/quant_api.py

Lines changed: 1 addition & 74 deletions
```diff
@@ -98,7 +98,6 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
-    _ConfigDeprecationWrapper,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -147,18 +146,7 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "uintx_weight_only",
-    "fpx_weight_only",
-    "gemlite_uintx_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -478,7 +466,7 @@ def quantize_(
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
-        from torchao.quantization.quant_api import int4_weight_only
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -610,12 +598,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module,
@@ -984,12 +966,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
 def _int4_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1050,12 +1026,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
-    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
 def _gemlite_uintx_weight_only_transform(
     module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1133,11 +1103,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
 
 
-# for BC
-# TODO maybe change other callsites
-int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
-
-
 def _int4_weight_only_quantize_tensor(weight, config):
     # TODO(future PR): perhaps move this logic to a different file, to keep the API
     # file clean of implementation details
@@ -1347,10 +1312,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
-
-
 def _int8_weight_only_quantize_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1536,12 +1497,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
-)
-
-
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     if config.version == 1:
         layout = config.layout
@@ -1675,12 +1630,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
 
 
-# for BC
-float8_weight_only = _ConfigDeprecationWrapper(
-    "float8_weight_only", Float8WeightOnlyConfig
-)
-
-
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1865,12 +1814,6 @@ def __post_init__(self):
         self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
 
 
-# for bc
-float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
-)
-
-
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -2079,12 +2022,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
-)
-
-
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2170,12 +2107,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-uintx_weight_only = _ConfigDeprecationWrapper(
-    "uintx_weight_only", UIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2469,10 +2400,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
-
-
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
```
torchao/utils.py

Lines changed: 1 addition & 21 deletions
```diff
@@ -8,11 +8,10 @@
 import itertools
 import re
 import time
-import warnings
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -375,25 +374,6 @@ def torch_version_at_least(min_version):
     return parse_version(torch.__version__) >= parse_version(min_version)
 
 
-class _ConfigDeprecationWrapper:
-    """
-    A deprecation wrapper that directs users from a deprecated "config function"
-    (e.g. `int4_weight_only`) to the replacement config class.
-    """
-
-    def __init__(self, deprecated_name: str, config_cls: Type):
-        self.deprecated_name = deprecated_name
-        self.config_cls = config_cls
-
-    def __call__(self, *args, **kwargs):
-        warnings.warn(
-            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
-            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
-            f"    quantize_(model, {self.config_cls.__name__}(...))"
-        )
-        return self.config_cls(*args, **kwargs)
-
-
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
```