Commit bcfc520

up

1 parent dfbd681 commit bcfc520

File tree

2 files changed: +230 -8 lines changed

torchao/experimental/quant_api.py

Lines changed: 149 additions & 0 deletions
@@ -1108,3 +1108,152 @@ def quantize(self, model: nn.Module) -> nn.Module:
             },
         )
         return model
+
+
+def _get_q_dq_patterns_and_replacements(weight_bit_width, has_weight_zeros, target):
+    w_qmin = -(1 << (weight_bit_width - 1))
+    w_qmax = (1 << (weight_bit_width - 1)) - 1
+    a_qmin = -128
+    a_qmax = 127
+
+    if not has_weight_zeros:
+
+        def pattern(a, w_int, w_scale, bias, group_size, a_block):
+            a_scale, a_zero = torch.ops.quant.choose_qparams_affine.default(
+                a,
+                "ASYMMETRIC",
+                a_block,
+                torch.int32,
+                a_qmin,
+                a_qmax,
+                None,
+                torch.float32,
+                torch.int32,
+            )
+            q_a = torch.ops.quant.quantize_affine.default(
+                a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
+            )
+            dq_a = torch.ops.quant.dequantize_affine.default(
+                q_a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
+            )
+            dq_w = torch.ops.quant.dequantize_affine.default(
+                w_int,
+                [1, group_size],
+                w_scale,
+                None,
+                torch.int32,
+                w_qmin,
+                w_qmax,
+                "NONE",
+            )
+            return torch.ops.aten.linear.default(dq_a, dq_w, bias)
+
+        def replacement(a, w_int, w_scale, bias, group_size, a_block):
+            n = w_int.size(0)
+            k = a_block[-1]
+            out_shape = a.shape[:-1] + (n,)
+            packed_weight = getattr(
+                torch.ops.torchao,
+                f"_pack_8bit_act_{weight_bit_width}bit_weight",
+            )(
+                w_int.to(torch.int8),
+                w_scale.reshape(-1),
+                None,
+                group_size,
+                bias,
+                target,
+            )
+            return getattr(
+                torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight"
+            )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape)
+
+    else:
+
+        def pattern(a, w_int, w_scale, w_zero, bias, group_size, a_block):
+            a_scale, a_zero = torch.ops.quant.choose_qparams_affine.default(
+                a,
+                "ASYMMETRIC",
+                a_block,
+                torch.int32,
+                a_qmin,
+                a_qmax,
+                None,
+                torch.float32,
+                torch.int32,
+            )
+            q_a = torch.ops.quant.quantize_affine.default(
+                a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
+            )
+            dq_a = torch.ops.quant.dequantize_affine.default(
+                q_a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
+            )
+            dq_w = torch.ops.quant.dequantize_affine.default(
+                w_int, [1, group_size], w_scale, w_zero, torch.int32, w_qmin, w_qmax
+            )
+            return torch.ops.aten.linear.default(dq_a, dq_w, bias)
+
+        def replacement(a, w_int, w_scale, w_zero, bias, group_size, a_block):
+            n = w_int.size(0)
+            k = a_block[-1]
+            out_shape = a.shape[:-1] + (n,)
+            packed_weight = getattr(
+                torch.ops.torchao,
+                f"_pack_8bit_act_{weight_bit_width}bit_weight",
+            )(
+                w_int.to(torch.int8),
+                w_scale.reshape(-1),
+                w_zero.reshape(-1).to(torch.int8),
+                group_size,
+                bias,
+                target,
+            )
+            return getattr(
+                torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight"
+            )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape)
+
+    return pattern, replacement
+
+
+def replace_q_dq_with_torchao_quantized_linear_ops(
+    ep: torch.export.ExportedProgram, target=None
+):
+    # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
+    assert (
+        len(ep.range_constraints) == 0
+    ), "ExportedProgram with range constraints are not supported"
+
+    import itertools
+
+    from torch._export.passes.constant_folding import constant_fold
+    from torch.fx import subgraph_rewriter
+
+    def filter_invalid_a_block(match, x, y):
+        """
+        We only want a_block with shape [1, ..., 1, k]
+        """
+        a_block_node = [n for n in match.nodes_map if n.name == "a_block"]
+        assert len(a_block_node) == 1
+        a_block_node = a_block_node[0]
+        a_block_node_val = match.nodes_map[a_block_node]
+        for v in a_block_node_val[0:-1]:
+            if v != 1:
+                return False
+        return True
+
+    gm = (
+        ep.module()
+    )  # module() unlifts the inputs, which is needed for constant folding
+    for weight_bit_width, has_weight_zeros in itertools.product(
+        range(1, 9), [False, True]
+    ):
+        pattern, replacement = _get_q_dq_patterns_and_replacements(
+            weight_bit_width, has_weight_zeros, target
+        )
+        subgraph_rewriter.replace_pattern_with_filters(
+            gm, pattern, replacement, match_filters=[filter_invalid_a_block]
+        )
+
+    # Constant fold evaluates and removes the packing ops
+    constant_fold(gm)
+
+    # Re-export
+    return torch.export.export(gm, *ep.example_inputs)
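
For orientation, a minimal usage sketch of the new pass, condensed from the test added in the second file below; the toy module, bit width, and granularity are illustrative only and are not part of this commit.

import torch

from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    replace_q_dq_with_torchao_quantized_linear_ops,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass

# Quantize with QDQLayout so the exported graph carries explicit
# choose_qparams / quantize / dequantize (Q/DQ) ops around each linear.
model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=True))
activations = torch.randn(2, 1, 256, dtype=torch.float32)
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(64),
        has_weight_zeros=True,
        layout=QDQLayout(),
    ),
)

# Export, then rewrite the Q/DQ patterns into torchao's packed quantized
# linear ops; the packing ops are constant-folded away inside the pass.
unwrap_tensor_subclass(model)
exported = torch.export.export(model, (activations,), strict=True)
exported = replace_q_dq_with_torchao_quantized_linear_ops(exported)

After the pass, each quantized linear appears as a single torch.ops.torchao._linear_8bit_act_*bit_weight call and no Q/DQ or packing ops remain, which is what the FileCheck assertions in the new test verify.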

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 81 additions & 8 deletions
@@ -17,7 +17,10 @@
     PackedLinearInt8DynamicActivationIntxWeightLayout,
 )
 from torchao.experimental.q_dq_layout import QDQLayout
-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
+from torchao.experimental.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    replace_q_dq_with_torchao_quantized_linear_ops,
+)
 from torchao.quantization.granularity import PerGroup, PerRow
 from torchao.quantization.quant_api import quantize_
 from torchao.utils import unwrap_tensor_subclass
@@ -79,7 +82,7 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
         quantized_model = copy.deepcopy(model)
         quantize_(
             quantized_model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -91,7 +94,7 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
         quantized_model_reference = copy.deepcopy(model)
         quantize_(
             quantized_model_reference,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -124,7 +127,7 @@ def test_accuracy_aten(self):
         quantized_model = copy.deepcopy(model)
         quantize_(
             quantized_model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -136,7 +139,7 @@ def test_accuracy_aten(self):
         quantized_model_reference = copy.deepcopy(model)
         quantize_(
             quantized_model_reference,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -183,7 +186,7 @@ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
 
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -245,7 +248,7 @@ def test_export_dynamic_shape_PackedLinearInt8DynamicActivationIntxWeightLayout(
 
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -278,7 +281,7 @@ def test_export_QDQLayout(self):
 
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
@@ -304,6 +307,76 @@ def test_export_QDQLayout(self):
             exported.graph_module.code
         )
 
+    def test_replace_q_dq_with_torchao_quantized_linear_ops(self):
+        layers = [
+            torch.nn.Linear(256, 128, bias=True),
+            torch.nn.Linear(128, 64, bias=False),
+            torch.nn.Linear(64, 32, bias=True),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(2, 1, 256, dtype=torch.float32)
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(64),
+                has_weight_zeros=True,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: fqn == "0",
+        )
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int3,
+                granularity=PerRow(),
+                has_weight_zeros=False,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: fqn == "1",
+        )
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int5,
+                granularity=PerGroup(32),
+                has_weight_zeros=False,
+                layout=QDQLayout(),
+            ),
+            lambda m, fqn: fqn == "2",
+        )
+
+        eager_results = model(activations)
+
+        unwrap_tensor_subclass(model)
+        exported = torch.export.export(model, (activations,), strict=True)
+        exported = replace_q_dq_with_torchao_quantized_linear_ops(exported)
+
+        # We should not find pack op because it gets constant folded
+        FileCheck().check_not("torch.ops.torchao._pack_8bit_act").run(
+            exported.graph_module.code
+        )
+
+        # We should find 3 torchao linear ops
+        FileCheck().check_count(
+            "torch.ops.torchao._linear_8bit_act_", count=3, exactly=True
+        ).run(exported.graph_module.code)
+
+        # We should not find Q/DQ ops
+        FileCheck().check_not("torch.ops.quant.quantize_affine.default").run(
+            exported.graph_module.code
+        )
+        FileCheck().check_not("torch.ops.quant.dequantize_affine.default").run(
+            exported.graph_module.code
+        )
+        FileCheck().check_not("torch.ops.quant.choose_qparams_affine.default").run(
+            exported.graph_module.code
+        )
+
+        # Numerics should match
+        exported_results = exported.module()(activations)
+        self.assertTrue(torch.allclose(exported_results, eager_results))
+
 
 if __name__ == "__main__":
     unittest.main()
