@@ -3047,42 +3047,60 @@ def test_fp8_q_attention_block(self):
30473047 annotate_matmul = annotate_matmul , is_fp8 = True
30483048 )
30493049
3050- @skipIfNoDynamoSupport
3051- @skipIfNoONEDNN
3052- @skipIfNoFloat8Support
3053- @unittest .skipIf (
3054- "CPU" not in torch ._C ._dispatch_dump ("torchao::_scaled_embedding_bag" ),
3055- reason = "cpp kernels not built" ,
3056- )
3057- def test_fp8_scaled_embedding_bag (self ):
3058- dtype = torch .float8_e4m3fn
3059-
3050+ def _test_scaled_embedding_bag_helper (self , dtype , with_quant = False ):
class FP8QDQEmbeddingBag(torch.nn.Module):
    """Embedding bag wrapped in explicit quantize/dequantize ops.

    The enclosing helper's ``dtype`` free variable selects the op set:
    fp8 (e4m3fn) uses the torchao non-decomposed affine ops, anything
    else takes the int8 ``quantized_decomposed`` per-tensor ops.  When
    the enclosing ``with_quant`` flag is set, the bag output is
    re-quantized so the with-quant fusion pattern can match.
    """

    def __init__(self):
        super().__init__()
        # Fixed scales — the pattern matcher only needs plausible constants.
        self.weight_scale = 2.0
        self.output_scale = 3.0

    def _dq(self, weight):
        # Dequantize the stored weight back to float32.
        if dtype == torch.float8_e4m3fn:
            return torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
                tensor=weight.data,
                scale=torch.tensor([self.weight_scale]),
                output_dtype=torch.float,
            )
        return torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            weight.data,
            self.weight_scale,
            0,
            -128,
            127,
            torch.int8,
        )

    def _q(self, x):
        # Quantize the embedding-bag result back to `dtype`.
        if dtype == torch.float8_e4m3fn:
            return torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
                tensor=x,
                scale=torch.tensor([self.output_scale]),
                float8_dtype=dtype,
            )
        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, self.output_scale, 0, -128, 127, torch.int8
        )

    def forward(self, weight, input, offsets=None):
        dq_weight = self._dq(weight)
        res = torch.nn.functional.embedding_bag(
            input,
            dq_weight,
            offsets,
            mode="sum",
            include_last_offset=True,
        )
        # Optionally requantize the output (with-quant pattern variant).
        if with_quant:
            res = self._q(res)
        return res
30873105 EMBEDINGBAG_MULTIHOT_SIZES = [1 , 2 , 3 , 10 ]
30883106 EMBEDINGBAG_BAG_SIZES = [1 , 2 , 128 , 1024 ]
@@ -3109,8 +3127,11 @@ def forward(
31093127 )
31103128
def matcher_check_fn():
    # The fusion pass bumps a distinct counter for the requantized-output
    # variant, so pick the expected counter name from `with_quant`.
    suffix = "_with_quant" if with_quant else ""
    self.assertEqual(
        counters["inductor"][f"scaled_embedding_bag{suffix}_matcher_count"], 1
    )
31153136
31163137 self ._test_common (
@@ -3120,6 +3141,38 @@ def matcher_check_fn():
31203141 )
31213142
31223143
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
    reason="cpp kernels not built",
)
def test_fp8_scaled_embedding_bag(self):
    """An fp8 (e4m3fn) dq->embedding_bag pattern fuses into _scaled_embedding_bag."""
    self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn)
3153+
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@unittest.skipIf(
    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
    reason="cpp kernels not built",
)
def test_int8_scaled_embedding_bag(self):
    """An int8 dq->embedding_bag pattern fuses into _scaled_embedding_bag.

    NOTE(review): the ``@skipIfNoFloat8Support`` gate is dropped here —
    this case exercises only the int8 ``quantized_decomposed`` ops, and
    the dispatch-dump check above already guards kernel availability, so
    machines without fp8 support should still run it.
    """
    self._test_scaled_embedding_bag_helper(torch.int8)
3163+
3164+
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@unittest.skipIf(
    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
    reason="cpp kernels not built",
)
def test_int8_scaled_embedding_bag_with_quant(self):
    """int8 dq->embedding_bag->q fuses into the _with_quant pattern variant.

    NOTE(review): the ``@skipIfNoFloat8Support`` gate is dropped — this
    is a pure int8 path (see test_int8_scaled_embedding_bag).  The flag
    is passed by keyword to make the boolean's meaning explicit.
    """
    self._test_scaled_embedding_bag_helper(torch.int8, with_quant=True)
3174+
3175+
31233176instantiate_parametrized_tests (TestPatternMatcher )
31243177if __name__ == "__main__" :
31253178 if IS_LINUX and HAS_CPU and torch .backends .mkldnn .is_available ():
0 commit comments