Commit 7cbdb86

support int8 scaled_embedding_bag pattern_match
1 parent f6519e3 commit 7cbdb86
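
In rough terms, the new int8 branch collapses a matched dequantize → _embedding_bag_forward_only → quantize chain into a single fused op. A sketch of the rewrite, pieced together from the diff below (names and the exact q/dq ops are illustrative, not the literal pattern nodes):

# matched pattern (sketch):
#   w     = dequantize_per_tensor(qw, w_scale, ...)
#   out   = aten._embedding_bag_forward_only(w, indices, offsets, ...)[0]
#   q_out = quantize_per_tensor(out, o_inv_scale, ..., torch.int8)  # the "o_dtype" case
#
# replacement (single fused node):
#   q_out = torchao._scaled_embedding_bag(
#       qw, indices, offsets,
#       w_scale,          # scalar wrapped into a [1] float tensor via aten.full
#       o_inv_scale,      # stays 1.0 for the plain fp32-output pattern
#       mode, include_last_offset,
#       torch.int8,       # output_type; torch.float when no output quant is matched
#   )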

File tree

1 file changed: +51 -11 lines changed
  • torchao/quantization/pt2e/inductor_passes


torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 51 additions & 11 deletions
@@ -2860,7 +2860,11 @@ def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float3
     def scaled_embedding_bag(match: Match, *args, **kwargs):
         assert dtype in [torch.float32, torch.bfloat16]

-        getitem_node = match.output_node()
+        if "o_dtype" in kwargs:
+            quant_node = match.output_node()
+            getitem_node = quant_node.args[0]
+        else:
+            getitem_node = match.output_node()
         embedding_bag_node = getitem_node.args[0]
         assert embedding_bag_node.target is aten._embedding_bag_forward_only.default

@@ -2889,11 +2893,25 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
             kwargs["mode"],
             kwargs["include_last_offset"],
         )
-        # only support fp32 output, next step to support more dtype
+        # only support fp32 and int8 output on kernel
+        # next step to support more output_type
+        output_type = torch.float
         o_scale = 1.0
+        if "o_dtype" in kwargs:
+            output_type = torch.int8
+            o_scale = kwargs["o_inv_scale"]

         graph = match.graph
         with graph.inserting_before(getitem_node):
+            # scale type is float on int8 q/dq
+            # Not support float scale yet on scaled_embedding_bag
+            # convert scale from float into tensor
+            if output_type == torch.int8:
+                w_scale = graph.call_function(
+                    torch.ops.aten.full.default,
+                    args=([1], w_scale),
+                    kwargs={"dtype": torch.float},
+                )
             new_args: tuple[Any, ...] = (
                 qw,
                 indices,
@@ -2902,13 +2920,18 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
                 o_scale,
                 mode,
                 include_last_offset,
-                torch.float,
+                output_type,
             )

             new_embedding_bag_node = graph.call_function(
                 torch.ops.torchao._scaled_embedding_bag.default, args=new_args
             )

+            # remove quant node
+            if output_type == torch.int8:
+                quant_node.replace_all_uses_with(getitem_node)
+                getitem_node.meta.update(quant_node.meta)
+                graph.erase_node(quant_node)
             getitem_node.replace_all_uses_with(new_embedding_bag_node)
             new_embedding_bag_node.meta.update(embedding_bag_node.meta)

@@ -2943,20 +2966,37 @@ def _generate_scaled_embedding_bag_patterns(dq_pattern):


 def _register_quantization_embeddingbag_pass():
-    for dtype in [torch.float32, torch.bfloat16]:
-        _register_scaled_embedding_bag_pass(
-            _generate_scaled_embedding_bag_patterns(
+    for is_fp8 in [True, False]:
+        for dtype in [torch.float32, torch.bfloat16]:
+            embeddingbag_pattern = _generate_scaled_embedding_bag_patterns(
                 _may_generate_pattern_with_dtype_convert(
                     get_dequantize_per_tensor_activation_pattern(
-                        is_tensor_overload=False, is_fp8=True
+                        is_tensor_overload=False, is_fp8=is_fp8
                     ),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
                 ),
-            ),
-            pass_number=1,
-            dtype=dtype,
-        )  # pass_number=0 to run before weight prepack
+            )
+
+            _register_scaled_embedding_bag_pass(
+                embeddingbag_pattern,
+                pass_number=1,
+                dtype=dtype
+            )
+
+            # will support fp8 output later
+            if not is_fp8:
+                embeddingbag_with_qoutput_pattern = generate_pattern_with_output_quant(
+                    embeddingbag_pattern,
+                    dtype == torch.bfloat16,
+                    is_fp8,
+                )
+
+                _register_scaled_embedding_bag_pass(
+                    embeddingbag_with_qoutput_pattern,
+                    pass_number=0,
+                    dtype=dtype,
+                )


 @functools.lru_cache(None)
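
For readers less familiar with FX graph surgery, the delicate part above is the cleanup order: the quant node's users must be rerouted to its input before the node is erased. A minimal, self-contained illustration of that idiom, assuming only torch.fx — the module M and the strip_last_node helper are hypothetical stand-ins, not part of this commit:

import torch
import torch.fx as fx

def strip_last_node(gm: fx.GraphModule) -> fx.GraphModule:
    # Remove the final op before the graph output, the way the pass
    # removes quant_node: reroute users, copy meta, then erase.
    graph = gm.graph
    output = next(n for n in graph.nodes if n.op == "output")
    last = output.args[0]                 # plays the role of quant_node
    upstream = last.args[0]               # plays the role of getitem_node
    last.replace_all_uses_with(upstream)  # reroute users first ...
    upstream.meta.update(last.meta)       # ... keep the metadata ...
    graph.erase_node(last)                # ... then erasing is safe (no users left)
    gm.recompile()
    return gm

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return torch.neg(y)  # stand-in for the quantize node being removed

gm = strip_last_node(fx.symbolic_trace(M()))
print(gm.code)  # forward now returns x + 1 directly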
