
Commit 4afcea8

committed
up
1 parent 78db1bd commit 4afcea8

File tree

4 files changed: +606 additions, -706 deletions


torchao/experimental/quant_api.py

Lines changed: 2 additions & 151 deletions
@@ -780,11 +780,11 @@ def _int8_dynamic_activation_intx_weight_transform(
             x,
             mapping_type=act_mapping_type,
             block_size=_get_per_token_block_size(x),
-            target_dtype=torch.int8,
+            target_dtype=torch.int32,
             quant_min=-128,  # lower bound of int8
             quant_max=127,  # upper bound of int8
             scale_dtype=torch.float32,
-            zero_point_dtype=torch.int8,
+            zero_point_dtype=torch.int32,
         )
         weight = to_linear_activation_quantized(weight, activation_quant_func)
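For orientation, here is a minimal standalone sketch (plain PyTorch on a made-up tensor, not the torchao API) of what this hunk changes: the dynamically quantized activation keeps the int8 value range [-128, 127], but the quantized values and zero point are now carried in int32 containers.

    # Hedged sketch: int8 value range, int32 storage, mirroring the
    # target_dtype / zero_point_dtype change in the hunk above.
    import torch

    x = torch.randn(4, 8)                               # hypothetical per-token activation
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0  # per-token scale
    zero_point = torch.zeros(4, 1, dtype=torch.int32)   # zero point stored as int32
    q = torch.clamp(torch.round(x / scale).to(torch.int32) + zero_point, -128, 127)
    dq = (q - zero_point).to(torch.float32) * scale     # dequantized activation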

@@ -1108,152 +1108,3 @@ def quantize(self, model: nn.Module) -> nn.Module:
             },
         )
         return model
-
-
-def _get_q_dq_patterns_and_replacements(weight_bit_width, has_weight_zeros, target):
-    w_qmin = -(1 << (weight_bit_width - 1))
-    w_qmax = (1 << (weight_bit_width - 1)) - 1
-    a_qmin = -128
-    a_qmax = 127
-
-    if not has_weight_zeros:
-
-        def pattern(a, w_int, w_scale, bias, group_size, a_block):
-            a_scale, a_zero = torch.ops.quant.choose_qparams_affine.default(
-                a,
-                "ASYMMETRIC",
-                a_block,
-                torch.int32,
-                a_qmin,
-                a_qmax,
-                None,
-                torch.float32,
-                torch.int32,
-            )
-            q_a = torch.ops.quant.quantize_affine.default(
-                a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
-            )
-            dq_a = torch.ops.quant.dequantize_affine.default(
-                q_a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
-            )
-            dq_w = torch.ops.quant.dequantize_affine.default(
-                w_int,
-                [1, group_size],
-                w_scale,
-                None,
-                torch.int32,
-                w_qmin,
-                w_qmax,
-                "NONE",
-            )
-            return torch.ops.aten.linear.default(dq_a, dq_w, bias)
-
-        def replacement(a, w_int, w_scale, bias, group_size, a_block):
-            n = w_int.size(0)
-            k = a_block[-1]
-            out_shape = a.shape[:-1] + (n,)
-            packed_weight = getattr(
-                torch.ops.torchao,
-                f"_pack_8bit_act_{weight_bit_width}bit_weight",
-            )(
-                w_int.to(torch.int8),
-                w_scale.reshape(-1),
-                None,
-                group_size,
-                bias,
-                target,
-            )
-            return getattr(
-                torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight"
-            )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape)
-    else:
-
-        def pattern(a, w_int, w_scale, w_zero, bias, group_size, a_block):
-            a_scale, a_zero = torch.ops.quant.choose_qparams_affine.default(
-                a,
-                "ASYMMETRIC",
-                a_block,
-                torch.int32,
-                a_qmin,
-                a_qmax,
-                None,
-                torch.float32,
-                torch.int32,
-            )
-            q_a = torch.ops.quant.quantize_affine.default(
-                a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
-            )
-            dq_a = torch.ops.quant.dequantize_affine.default(
-                q_a, a_block, a_scale, a_zero, torch.int32, a_qmin, a_qmax
-            )
-            dq_w = torch.ops.quant.dequantize_affine.default(
-                w_int, [1, group_size], w_scale, w_zero, torch.int32, w_qmin, w_qmax
-            )
-            return torch.ops.aten.linear.default(dq_a, dq_w, bias)
-
-        def replacement(a, w_int, w_scale, w_zero, bias, group_size, a_block):
-            n = w_int.size(0)
-            k = a_block[-1]
-            out_shape = a.shape[:-1] + (n,)
-            packed_weight = getattr(
-                torch.ops.torchao,
-                f"_pack_8bit_act_{weight_bit_width}bit_weight",
-            )(
-                w_int.to(torch.int8),
-                w_scale.reshape(-1),
-                w_zero.reshape(-1).to(torch.int8),
-                group_size,
-                bias,
-                target,
-            )
-            return getattr(
-                torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight"
-            )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape)
-
-    return pattern, replacement
-
-
-def replace_q_dq_with_torchao_quantized_linear_ops(
-    ep: torch.export.ExportedProgram, target=None
-):
-    # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
-    assert (
-        len(ep.range_constraints) == 0
-    ), "ExportedProgram with range constraints are not supported"
-
-    import itertools
-
-    from torch._export.passes.constant_folding import constant_fold
-    from torch.fx import subgraph_rewriter
-
-    def filter_invalid_a_block(match, x, y):
-        """
-        We only want a_block with shape [1, ..., 1, k]
-        """
-        a_block_node = [n for n in match.nodes_map if n.name == "a_block"]
-        assert len(a_block_node) == 1
-        a_block_node = a_block_node[0]
-        a_block_node_val = match.nodes_map[a_block_node]
-        for v in a_block_node_val[0:-1]:
-            if v != 1:
-                return False
-        return True
-
-    gm = (
-        ep.module()
-    )  # module() unlifts the inputs, which is needed for constant folding
-    for weight_bit_width, has_weight_zeros in itertools.product(
-        range(1, 9), [False, True]
-    ):
-        pattern, replacement = _get_q_dq_patterns_and_replacements(
-            weight_bit_width, has_weight_zeros, target
-        )
-        subgraph_rewriter.replace_pattern_with_filters(
-            gm, pattern, replacement, match_filters=[filter_invalid_a_block]
-        )
-
-    # Constant fold evaluates and removes the packing ops
-    constant_fold(gm)
-
-    # Re-export
-    return torch.export.export(gm, *ep.example_inputs)
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import itertools
+from collections import defaultdict
+from typing import Callable, Optional
+
+import torch
+from torch._export.passes.constant_folding import (
+    ConstantFolder,
+    replace_node_with_constant,
+)
+from torch.fx import subgraph_rewriter
+
+
+def constant_fold(
+    gm: torch.fx.GraphModule,
+    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+    skip_constructors: bool = False,
+):
+    with torch.utils._python_dispatch._disable_current_modes():
+        # The ConstantFolder has a bug where it throws if dequantize_affine is not defined
+        # TODO: fix upstream
+        try:
+            getattr(torch.ops.pt2e_quant, "dequantize_affine")
+        except AttributeError:
+            setattr(torch.ops.pt2e_quant, "dequantize_affine", None)
+
+        cf = ConstantFolder(gm, skip_constructors)
+        cf.run()
+
+        for node, constant in cf.node_replacements.items():
+            if constraint_fn is not None and not constraint_fn(node):
+                continue
+            replace_node_with_constant(gm, node, constant)
+
+        erased_params = []
+        # Get all attr users by looking up the graph instead from node.users, because in this case
+        # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor.
+
+        # opcode         name                 target            args                         kwargs
+        # -------------  -------------------  ----------------  ---------------------------  --------
+        # placeholder    arg0_1               arg0              ()                           {}
+        # get_attr       _tensor_constant0    state             ()                           {}
+        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0)  {}
+        # get_attr       _tensor_constant0_1  state             ()                           {}
+        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)     {}
+        # output         output               output            ([add],)                     {}
+
+        get_attr_node_users = defaultdict(list)
+        for node in gm.graph.nodes:
+            if node.op == "get_attr":
+                get_attr_node_users[node.target].extend(node.users.keys())
+        for node in gm.graph.find_nodes(op="get_attr"):
+            if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
+                if hasattr(gm, node.target):
+                    delattr(gm, node.target)
+                erased_params.append(node)
+        for node in erased_params:
+            gm.graph.erase_node(node)
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+
+
+def _get_q_dq_linear_patterns_replacements_and_filters(
+    weight_bit_width, has_weight_zeros, target
+):
+    glbs = globals()
+    glbs["weight_bit_width"] = weight_bit_width
+    glbs["target"] = target
+    glbs["w_quant_min"] = -(1 << (weight_bit_width - 1))
+    glbs["w_quant_max"] = (1 << (weight_bit_width - 1)) - 1
+    glbs["a_quant_min"] = -128
+    glbs["a_quant_max"] = 127
+    glbs["a_mapping_type"] = "ASYMMETRIC"
+    glbs["a_scale_dtype"] = torch.float32
+    glbs["a_eps"] = None
+
+    lcls = {}
+
+    pattern_str = f"""
+def pattern(
+    a, a_block_size, a_target_dtype, a_zero_point_dtype,
+    w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype,
+    bias):
+    a_scale, a_zero_point = torch.ops.quant.choose_qparams_affine.default(
+        a,
+        a_mapping_type,
+        a_block_size,
+        a_target_dtype,
+        a_quant_min,
+        a_quant_max,
+        a_eps,
+        a_scale_dtype,
+        a_zero_point_dtype,
+    )
+    a_int_data = torch.ops.quant.quantize_affine.default(
+        a, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max,
+    )
+    dq_a = torch.ops.quant.dequantize_affine.default(
+        a_int_data, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max
+    )
+    dq_w = torch.ops.quant.dequantize_affine.default(
+        w_int_data,
+        w_block_size,
+        w_scale,
+        w_zero_point,
+        w_target_dtype,
+        w_quant_min,
+        w_quant_max,
+        {"'INT'" if has_weight_zeros else "'NONE'"}
+    )
+    return torch.ops.aten.linear.default(dq_a, dq_w, bias)
+"""
+    exec(pattern_str, glbs, lcls)
+    pattern = lcls["pattern"]
+
+    replacement_str = f"""
+def replacement(
+    a, a_block_size, a_target_dtype, a_zero_point_dtype,
+    w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype,
+    bias,):
+    n = w_int_data.size(0)
+    k = a_block_size[-1]
+    group_size = w_block_size[-1]
+    out_shape = a.shape[:-1] + (n,)
+    packed_weight = getattr(
+        torch.ops.torchao,
+        f"_pack_8bit_act_{weight_bit_width}bit_weight",
+    )(
+        w_int_data.to(torch.int8),
+        w_scale.reshape(-1),
+        {"w_zero_point.reshape(-1).to(torch.int8)" if has_weight_zeros else "None"},
+        group_size,
+        bias,
+        target,
+    )
+    return getattr(
+        torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight"
+    )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape)
+"""
+
+    exec(replacement_str, glbs, lcls)
+    replacement = lcls["replacement"]
+
+    def match_filter(match, x, y):
+        def get_val(name):
+            node = [n for n in match.nodes_map if n.name == name][0]
+            return match.nodes_map[node]
+
+        int_types = [torch.int8, torch.int16, torch.int32, torch.int64]
+
+        a_target_dtype = get_val("a_target_dtype")
+        if a_target_dtype not in int_types:
+            return False
+
+        a_zero_point_dtype = get_val("a_zero_point_dtype")
+        if a_zero_point_dtype not in int_types:
+            return False
+
+        # We only want a_block_size with shape [1, ..., 1, k]
+        a_block_size = get_val("a_block_size")
+        for d in a_block_size[0:-1]:
+            if d != 1:
+                print("a_block_size not [1, ..., 1, k]")
+                return False
+
+        # We only want w_block_size with shape [1, group_size]
+        w_block_size = get_val("w_block_size")
+        if len(w_block_size) != 2 or w_block_size[0] != 1:
+            return False
+
+        return True
+
+    return pattern, replacement, match_filter
+
+
+def replace_q_dq_patterns_with_quantized_linear_ops_pass(
+    ep: torch.export.ExportedProgram,
+    target=None,
+) -> torch.export.ExportedProgram:
+    """
+    This replaces Q/DQ patterns with torchao quantized linear ops.
+    It is intended for converting Q/DQ nodes exported with QDQLayout to using
+    the lowbit quantized linear ops.
+    """
+    # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
+    # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/
+    assert (
+        len(ep.range_constraints) == 0
+    ), "ExportedProgram with range constraints are not supported"
+
+    # ep.module() unlifts the weight inputs, which we need for constant folding
+    gm = ep.module()
+    for weight_bit_width, has_weight_zeros in itertools.product(
+        range(1, 9), [True, False]
+    ):
+        pattern, replacement, match_filter = (
+            _get_q_dq_linear_patterns_replacements_and_filters(
+                weight_bit_width, has_weight_zeros, target
+            )
+        )
+        subgraph_rewriter.replace_pattern_with_filters(
+            gm, pattern, replacement, match_filters=[match_filter]
+        )
+
+    # Constant fold evaluates and removes the packing ops
+    constant_fold(gm)
+
+    # Re-export
+    return torch.export.export(gm, *ep.example_inputs)
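A minimal usage sketch for the new pass, under the assumption that the model was already quantized with a torchao config that uses QDQLayout so the exported graph contains the Q/DQ linear pattern. The toy model, shapes, and quantization step below are hypothetical, and the import path for the pass depends on where this new module lives in the tree (not shown in this diff).

    import torch
    import torch.nn as nn

    # Hypothetical toy model; in practice it would already be quantized with a
    # torchao int8-dynamic-activation / low-bit-weight config using QDQLayout.
    model = nn.Sequential(nn.Linear(256, 128)).eval()
    example_inputs = (torch.randn(1, 256),)

    # Export with static shapes (the pass asserts there are no range constraints),
    # then rewrite the Q/DQ patterns into torchao quantized linear ops and re-export.
    ep = torch.export.export(model, example_inputs)
    ep = replace_q_dq_patterns_with_quantized_linear_ops_pass(ep, target=None)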
