# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -196,15 +196,40 @@ def convert(
196196 """
197197 self ._convert_helper (model )
198198 return model
199+
200+ @staticmethod
201+ def quantize_weights (
202+ weight : torch .Tensor ,
203+ bit_width : int ,
204+ group_size : int ,
205+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
206+ """
207+ Helper function to quantize weights
208+ """
209+ (qmin , qmax ) = _get_qmin_qmax (bit_width )
210+ (s , zp ) = get_group_qparams_symmetric (
211+ weight , bit_width , group_size
212+ )
213+ from torchao ._executorch_ops import (
214+ _quantized_decomposed_quantize_per_channel_group_wrapper ,
215+ )
216+ q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper (
217+ weight ,
218+ s ,
219+ zp ,
220+ qmin ,
221+ qmax ,
222+ torch .int8 ,
223+ group_size ,
224+ )
225+ return (q_weight , s , zp )
226+
199227
200228 def _convert_helper (self , module : torch .nn .Module ):
201229 """
202230 Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
203231 modules with `Int4WeightOnlyEmbedding`
204232 """
205- from torchao ._executorch_ops import (
206- _quantized_decomposed_quantize_per_channel_group_wrapper ,
207- )
208233
209234 for name , child in module .named_children ():
210235 if isinstance (child , Int4WeightOnlyQATEmbedding ):
@@ -230,20 +255,8 @@ def _convert_helper(self, module: torch.nn.Module):
230255 )
231256 setattr (module , name , quantized_embedding )
232257
258+ q_weight , s , zp = self .quantize_weights (child .weight , self .bit_width , group_size )
233259 # Load weights and qparams into quantized embedding
234- (qmin , qmax ) = _get_qmin_qmax (self .bit_width )
235- (s , zp ) = get_group_qparams_symmetric (
236- child .weight , self .bit_width , group_size
237- )
238- q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper (
239- child .weight ,
240- s ,
241- zp ,
242- qmin ,
243- qmax ,
244- torch .int8 ,
245- group_size ,
246- )
247260 quantized_embedding .weight = q_weight
248261 quantized_embedding .scale = s .to (scale_precision )
249262 quantized_embedding .zero_point = zp .to (zero_point_precision )