Skip to content

Commit 534bea5

Browse files
authored
flip mx inference scaling setting to RCEIL (#3428)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent ca2132e commit 534bea5

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

torchao/prototype/mx_formats/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ Note: the accuracy results below are WIP and are not optimized yet.
230230
| recipe | wikitext word_perplexity | winogrande |
231231
| ------ | -------- | ---------- |
232232
| bfloat16 (baseline) | 7.5472105433748435 | 0.7426992896606156 |
- | mxfp8 | 7.609070006132819 | 0.7292817679558011 |
+ | mxfp8 | 7.605192917647689 | 0.7355958958168903 |
234234
| nvfp4 | 8.44478255417328 | 0.7182320441988951 |
235235

236236
To reproduce:

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def _mx_inference_linear_transform(
8585
block_size=config.block_size,
8686
kernel_preference=config.kernel_preference,
8787
is_swizzled_scales=True,
+ scaling_mode=ScaleCalculationMode.RCEIL,
8889
)
8990

9091
# Convert weight to MX Tensor
@@ -95,6 +96,7 @@ def _mx_inference_linear_transform(
9596
kernel_preference=config.kernel_preference,
9697
act_quant_kwargs=act_quant_kwargs,
9798
is_swizzled_scales=True,
+ scaling_mode=ScaleCalculationMode.RCEIL,
98100
)
99101

100102
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
class QuantizeTensorToMXKwargs(QuantizeTensorKwargs):
8888
elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn
8989
block_size: int = 32
+ # TODO(future PR): flip the scaling_mode default to RCEIL
9091
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
9192
kernel_preference: KernelPreference = KernelPreference.EMULATED
9293
is_swizzled_scales: bool = False
@@ -533,6 +534,7 @@ def to_mx(
533534
data_hp: torch.Tensor,
534535
elem_dtype: Union[torch.dtype, str],
535536
block_size: int = BLOCK_SIZE_DEFAULT,
+ # TODO(future PR): flip the scaling_mode default to RCEIL
536538
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
537539
# TODO(future PR): switch default gemm to cublas
538540
kernel_preference: KernelPreference = KernelPreference.EMULATED,

0 commit comments

Comments
 (0)