File tree Expand file tree Collapse file tree 3 files changed +5
-1
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 3 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -230,7 +230,7 @@ Note: the accuracy results below are WIP and are not optimized yet.
230230| recipe | wikitext word_perplexity | winogrande |
231231| ------ | -------- | ---------- |
232232| bfloat16 (baseline) | 7.5472105433748435 | 0.7426992896606156 |
233- | mxfp8 | 7.609070006132819 | 0.7292817679558011 |
233+ | mxfp8 | 7.605192917647689 | 0.7355958958168903 |
234234| nvfp4 | 8.44478255417328 | 0.7182320441988951 |
235235
236236To reproduce:
Original file line number Diff line number Diff line change @@ -85,6 +85,7 @@ def _mx_inference_linear_transform(
8585 block_size = config .block_size ,
8686 kernel_preference = config .kernel_preference ,
8787 is_swizzled_scales = True ,
88+ scaling_mode = ScaleCalculationMode .RCEIL ,
8889 )
8990
9091 # Convert weight to MX Tensor
@@ -95,6 +96,7 @@ def _mx_inference_linear_transform(
9596 kernel_preference = config .kernel_preference ,
9697 act_quant_kwargs = act_quant_kwargs ,
9798 is_swizzled_scales = True ,
99+ scaling_mode = ScaleCalculationMode .RCEIL ,
98100 )
99101
100102 module .weight = torch .nn .Parameter (quantized_weight , requires_grad = False )
Original file line number Diff line number Diff line change 8787class QuantizeTensorToMXKwargs (QuantizeTensorKwargs ):
8888 elem_dtype : Union [torch .dtype , str ] = torch .float8_e4m3fn
8989 block_size : int = 32
90+ # TODO(future PR): flip the scaling_mode default to RCEIL
9091 scaling_mode : ScaleCalculationMode = ScaleCalculationMode .FLOOR
9192 kernel_preference : KernelPreference = KernelPreference .EMULATED
9293 is_swizzled_scales : bool = False
@@ -533,6 +534,7 @@ def to_mx(
533534 data_hp : torch .Tensor ,
534535 elem_dtype : Union [torch .dtype , str ],
535536 block_size : int = BLOCK_SIZE_DEFAULT ,
537+ # TODO(future PR): flip the scaling_mode default to RCEIL
536538 scaling_mode : ScaleCalculationMode = ScaleCalculationMode .FLOOR ,
537539 # TODO(future PR): switch default gemm to cublas
538540 kernel_preference : KernelPreference = KernelPreference .EMULATED ,
You can’t perform that action at this time.
0 commit comments