
Commit de854bc

add unit test for int8 scaled_embedding_bag pattern match
1 parent 7cbdb86 commit de854bc

2 files changed: +81 -28 lines


test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 72 additions & 19 deletions
@@ -3047,42 +3047,60 @@ def test_fp8_q_attention_block(self):
         annotate_matmul=annotate_matmul, is_fp8=True
     )
 
-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    @skipIfNoFloat8Support
-    @unittest.skipIf(
-        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-        reason="cpp kernels not built",
-    )
-    def test_fp8_scaled_embedding_bag(self):
-        dtype = torch.float8_e4m3fn
-
+    def _test_scaled_embedding_bag_helper(self, dtype, with_quant=False):
         class FP8QDQEmbeddingBag(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.weight_scale = 2.0
+                self.output_scale = 3.0
+
+            def _dq(self, weight):
+                if dtype == torch.float8_e4m3fn:
+                    res = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+                        tensor=weight.data,
+                        scale=torch.tensor([self.weight_scale]),
+                        output_dtype=torch.float,
+                    )
+                else:
+                    res = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                        weight.data,
+                        self.weight_scale,
+                        0,
+                        -128, 127, torch.int8,
+                    )
+                return res
+
+            def _q(self, x):
+                if dtype == torch.float8_e4m3fn:
+                    qx = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+                        tensor=x,
+                        scale=torch.tensor([self.output_scale]),
+                        float8_dtype=dtype,
+                    )
+                else:
+                    qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+                        x, self.output_scale, 0, -128, 127, torch.int8
+                    )
+                return qx
 
             def forward(
                 self,
                 weight,
                 input,
                 offsets=None,
             ):
-                weight = (
-                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
-                        tensor=weight.data,
-                        scale=torch.tensor([self.weight_scale]),
-                        output_dtype=torch.float,
-                    )
-                )
+                weight = self._dq(weight)
 
-                return torch.nn.functional.embedding_bag(
+                res = torch.nn.functional.embedding_bag(
                     input,
                     weight,
                     offsets,
                     mode="sum",
                     include_last_offset=True,
                 )
+                if with_quant:
+                    res = self._q(res)
+                return res
 
         EMBEDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
         EMBEDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
@@ -3109,8 +3127,11 @@ def forward(
         )
 
         def matcher_check_fn():
+            counter_name = "scaled_embedding_bag"
+            if with_quant:
+                counter_name += "_with_quant"
             self.assertEqual(
-                counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
+                counters["inductor"][f"{counter_name}_matcher_count"], 1
             )
 
         self._test_common(
@@ -3120,6 +3141,38 @@ def matcher_check_fn():
         )
 
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_fp8_scaled_embedding_bag(self):
+        self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_int8_scaled_embedding_bag(self):
+        self._test_scaled_embedding_bag_helper(torch.int8)
+
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_int8_scaled_embedding_bag_with_quant(self):
+        self._test_scaled_embedding_bag_helper(torch.int8, True)
+
+
 instantiate_parametrized_tests(TestPatternMatcher)
 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
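For context, here is a minimal standalone sketch (not part of the commit; the shapes and scale values are made up for illustration) of the int8 quantize/dequantize round-trip that _test_scaled_embedding_bag_helper builds around embedding_bag:

import torch

# dequantize an int8 weight, bag it, then (with_quant=True) re-quantize the output
weight = torch.randn(16, 8)
weight_scale, output_scale = 2.0, 3.0

qweight = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    weight, weight_scale, 0, -128, 127, torch.int8
)
dq_weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    qweight, weight_scale, 0, -128, 127, torch.int8
)

indices = torch.tensor([0, 1, 2, 3])
offsets = torch.tensor([0, 2, 4])  # include_last_offset=True: last entry == len(indices)
res = torch.nn.functional.embedding_bag(
    indices, dq_weight, offsets, mode="sum", include_last_offset=True
)
qres = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    res, output_scale, 0, -128, 127, torch.int8
)

It is this dq -> embedding_bag (-> q) chain that the Inductor pass below rewrites into a single torchao._scaled_embedding_bag call.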

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 9 additions & 9 deletions
@@ -2893,20 +2893,17 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
         kwargs["mode"],
         kwargs["include_last_offset"],
     )
-    # only support fp32 and int8 output on kernel
-    # next step to support more output_type
     output_type = torch.float
     o_scale = 1.0
     if "o_dtype" in kwargs:
-        output_type = torch.int8
+        output_type = kwargs["o_dtype"]
         o_scale = kwargs["o_inv_scale"]
 
     graph = match.graph
     with graph.inserting_before(getitem_node):
-        # scale type is float on int8 q/dq
-        # Not support float scale yet on scaled_embedding_bag
+        # float scale not supported on scaled_embedding_bag
         # convert scale from float into tensor
-        if output_type == torch.int8:
+        if type(w_scale) is float:
             w_scale = graph.call_function(
                 torch.ops.aten.full.default,
                 args=([1], w_scale),
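The new `if type(w_scale) is float:` guard generalizes the old int8-only check: whenever the matched weight scale arrives as a Python float (as it does for int8 q/dq), it is materialized as a 1-element tensor, since the kernel only accepts tensor scales. In eager terms the inserted node computes something like the following (illustrative sketch, not the pass itself):

import torch

w_scale = 2.0  # scale captured from the dequant pattern as a Python float
w_scale_t = torch.ops.aten.full.default([1], w_scale, dtype=torch.float32)
# w_scale_t is tensor([2.]), the 1-element form _scaled_embedding_bag expects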
@@ -2927,7 +2924,7 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
         torch.ops.torchao._scaled_embedding_bag.default, args=new_args
     )
 
-    # remove quant node
+    # Erase quant pattern
     if output_type == torch.int8:
         quant_node.replace_all_uses_with(getitem_node)
         getitem_node.meta.update(quant_node.meta)
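When an output quant node was matched, its consumers are rewired to the fused op's getitem result before the quant node is erased. A self-contained torch.fx sketch of that rewiring idiom, using ordinary add/mul nodes as stand-ins (an assumed simplification, not the pass itself):

import operator

import torch
from torch import fx

def f(x):
    y = x + 1.0
    z = y * 2.0  # stand-in for the redundant quant node
    return z

gm = fx.symbolic_trace(f)
g = gm.graph
add = next(n for n in g.nodes if n.target is operator.add)
mul = next(n for n in g.nodes if n.target is operator.mul)

# same idiom as quant_node.replace_all_uses_with(getitem_node) above
mul.replace_all_uses_with(add)
add.meta.update(mul.meta)
g.erase_node(mul)
gm.recompile()
assert torch.equal(gm(torch.ones(2)), torch.ones(2) + 1.0)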
@@ -2942,8 +2939,11 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
     # Erase the dequant pattern
     graph.erase_node(dequant_node)
 
-    counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
-    counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)
+    counter_name = "scaled_embedding_bag"
+    if "o_dtype" in kwargs:
+        counter_name += "_with_quant"
+    counters["inductor"][f"{counter_name}_matcher_count"] += 1
+    counters["inductor"][f"{counter_name}_matcher_nodes"] += len(match.nodes)
 
 
 def _generate_scaled_embedding_bag_patterns(dq_pattern):
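Matches without an output quant node keep the old counter names, while fused dq + embedding_bag + q matches get a `_with_quant` suffix; this is what the updated matcher_check_fn in the test asserts on. A hypothetical mirror of the selection logic:

def _counter_name(kwargs: dict) -> str:
    # "o_dtype" is only present when an output quant node was matched
    name = "scaled_embedding_bag"
    if "o_dtype" in kwargs:
        name += "_with_quant"
    return name

assert _counter_name({}) == "scaled_embedding_bag"
assert _counter_name({"o_dtype": "int8"}) == "scaled_embedding_bag_with_quant"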
