@@ -2860,7 +2860,11 @@ def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float32):
     def scaled_embedding_bag(match: Match, *args, **kwargs):
         assert dtype in [torch.float32, torch.bfloat16]
 
-        getitem_node = match.output_node()
+        if "o_dtype" in kwargs:
+            quant_node = match.output_node()
+            getitem_node = quant_node.args[0]
+        else:
+            getitem_node = match.output_node()
         embedding_bag_node = getitem_node.args[0]
         assert embedding_bag_node.target is aten._embedding_bag_forward_only.default
 
@@ -2889,11 +2893,25 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
             kwargs["mode"],
             kwargs["include_last_offset"],
         )
-        # only support fp32 output, next step to support more dtype
+        # the kernel only supports fp32 and int8 output;
+        # more output dtypes will be supported later
+        output_type = torch.float
         o_scale = 1.0
+        if "o_dtype" in kwargs:
+            output_type = torch.int8
+            o_scale = kwargs["o_inv_scale"]
 
         graph = match.graph
         with graph.inserting_before(getitem_node):
+            # the int8 q/dq scale arrives as a Python float, but
+            # scaled_embedding_bag does not accept a float scale yet, so
+            # materialize it as a 1-element tensor (sketched after the diff)
+            if output_type == torch.int8:
+                w_scale = graph.call_function(
+                    torch.ops.aten.full.default,
+                    args=([1], w_scale),
+                    kwargs={"dtype": torch.float},
+                )
             new_args: tuple[Any, ...] = (
                 qw,
                 indices,
@@ -2902,13 +2920,18 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
                 o_scale,
                 mode,
                 include_last_offset,
-                torch.float,
+                output_type,
             )
 
             new_embedding_bag_node = graph.call_function(
                 torch.ops.torchao._scaled_embedding_bag.default, args=new_args
            )
 
+            # remove the matched quant node (idiom sketched after the diff)
+            if output_type == torch.int8:
+                quant_node.replace_all_uses_with(getitem_node)
+                getitem_node.meta.update(quant_node.meta)
+                graph.erase_node(quant_node)
             getitem_node.replace_all_uses_with(new_embedding_bag_node)
             new_embedding_bag_node.meta.update(embedding_bag_node.meta)
 
@@ -2943,20 +2966,37 @@ def _generate_scaled_embedding_bag_patterns(dq_pattern):
 
 
 def _register_quantization_embeddingbag_pass():
-    for dtype in [torch.float32, torch.bfloat16]:
-        _register_scaled_embedding_bag_pass(
-            _generate_scaled_embedding_bag_patterns(
+    for is_fp8 in [True, False]:
+        for dtype in [torch.float32, torch.bfloat16]:
+            embeddingbag_pattern = _generate_scaled_embedding_bag_patterns(
                 _may_generate_pattern_with_dtype_convert(
                     get_dequantize_per_tensor_activation_pattern(
-                        is_tensor_overload=False, is_fp8=True
+                        is_tensor_overload=False, is_fp8=is_fp8
                     ),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
                 ),
-            ),
-            pass_number=1,
-            dtype=dtype,
-        )  # pass_number=0 to run before weight prepack
+            )
+
+            _register_scaled_embedding_bag_pass(
+                embeddingbag_pattern,
+                pass_number=1,
+                dtype=dtype,
+            )
+
+            # fp8 output will be supported later
+            if not is_fp8:
+                embeddingbag_with_qoutput_pattern = generate_pattern_with_output_quant(
+                    embeddingbag_pattern,
+                    dtype == torch.bfloat16,
+                    is_fp8,
+                )
+
+                _register_scaled_embedding_bag_pass(
+                    embeddingbag_with_qoutput_pattern,
+                    pass_number=0,
+                    dtype=dtype,
+                )
 
 
 @functools.lru_cache(None)
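The float-to-tensor scale conversion in the second hunk is easiest to see outside the FX pass. A minimal sketch, assuming only that aten.full is invoked the same way the inserted call_function node invokes it (the scale value here is made up):

import torch

# hypothetical weight scale recovered from the matched q/dq pattern
w_scale_float = 0.05

# scaled_embedding_bag does not take a Python float scale yet, so the
# pass materializes it as a 1-element float32 tensor via aten.full
w_scale_tensor = torch.ops.aten.full.default([1], w_scale_float, dtype=torch.float)

assert w_scale_tensor.shape == (1,)
assert w_scale_tensor.dtype == torch.float32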
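The quant-node removal in the third hunk follows the standard torch.fx rewrite idiom: reroute the node's users to its input, preserve its meta, then erase it. A toy illustration on a traced graph; the function and the node being dropped are illustrative stand-ins, not the real pattern:

import operator
import torch
import torch.fx as fx

def f(x):
    y = x + 1
    z = y * 2  # stands in for the quant node being dropped
    return z

gm = fx.symbolic_trace(f)
graph = gm.graph
for node in list(graph.nodes):
    if node.op == "call_function" and node.target is operator.mul:
        src = node.args[0]
        node.replace_all_uses_with(src)  # consumers now read from y
        src.meta.update(node.meta)       # keep metadata, as the pass does
        graph.erase_node(node)           # safe: node has no users left
graph.lint()
gm.recompile()
assert torch.equal(gm(torch.ones(2)), torch.ones(2) + 1)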
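Finally, the nested loops in _register_quantization_embeddingbag_pass expand to a small registration grid. A sketch of the combinations, with placeholder strings instead of the real pattern objects:

from itertools import product

registrations = []
for is_fp8, dtype in product([True, False], ["float32", "bfloat16"]):
    # plain fp32/bf16 output: registered for both fp8 and int8 inputs
    registrations.append((is_fp8, dtype, "plain_output", 1))
    if not is_fp8:  # quantized output is int8-only for now
        # pass_number=0 so it runs before weight prepack
        registrations.append((is_fp8, dtype, "quantized_output", 0))

# 4 plain-output registrations plus 2 quantized-output ones
assert len(registrations) == 6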