@@ -262,7 +262,6 @@ def create_weights(
262262 layer .input_size_per_partition = input_size_per_partition
263263 layer .output_size_per_partition = output_size_per_partition
264264 layer .orig_dtype = params_dtype
265- layer .executed_weight_requant_ue8m0 = False
266265
267266 # WEIGHT
268267 weight_dtype = (
@@ -300,6 +299,7 @@ def create_weights(
300299 output_dim = 0 ,
301300 weight_loader = weight_loader ,
302301 )
302+ scale .format_ue8m0 = False
303303 scale [:] = torch .finfo (torch .float32 ).min
304304 layer .register_parameter ("weight_scale_inv" , scale )
305305 else :
@@ -367,14 +367,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
367367 self .w8a8_block_fp8_linear
368368 is deepgemm_w8a8_block_fp8_linear_with_fallback
369369 )
370- and (not layer .executed_weight_requant_ue8m0 )
370+ and (not layer .weight_scale_inv . format_ue8m0 )
371371 ):
372372 requant_weight_ue8m0_inplace (
373373 layer .weight ,
374374 layer .weight_scale_inv ,
375375 self .quant_config .weight_block_size ,
376376 )
377- layer .executed_weight_requant_ue8m0 = True
377+ layer .weight_scale_inv . format_ue8m0 = True
378378 weight , weight_scale = layer .weight .data , layer .weight_scale_inv .data
379379
380380 layer .weight .data = weight .data
0 commit comments