Hi,
We are testing FP8 weight caching in Transformer Engine by passing is_first_microbatch=True/False during gradient accumulation. With the Float8BlockScaling recipe we see two unexpected behaviors:
- Enabling FP8 weight caching does not improve training throughput in our setup. In some runs it even makes training slower.
- Enabling FP8 weight caching changes the training curve at some steps. Under
Float8BlockScaling, we expected the runs with and without caching to be numerically identical, or at least fully aligned in practice.
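To make the setup concrete, here is a toy stand-in for the pattern we drive (illustrative only; MockFP8Layer and its crude rounding are our own mock, not Transformer Engine's implementation):

```python
# Toy stand-in for an FP8-weight-caching module (our own mock, not TE code):
# it requantizes weights only when is_first_microbatch=True and reuses the
# cached copy otherwise, mirroring how we call the real layer.
class MockFP8Layer:
    def __init__(self, weight):
        self.weight = weight
        self.cached_q = None
        self.quant_calls = 0

    def _quantize(self, w):
        self.quant_calls += 1
        return [round(x * 127) / 127 for x in w]  # crude stand-in for FP8

    def __call__(self, inp, is_first_microbatch):
        if is_first_microbatch or self.cached_q is None:
            self.cached_q = self._quantize(self.weight)
        return sum(i * w for i, w in zip(inp, self.cached_q))

layer = MockFP8Layer([0.5, -0.25, 0.75])
outs = [layer([1.0, 2.0, 3.0], is_first_microbatch=(i == 0)) for i in range(4)]
assert layer.quant_calls == 1  # weights quantized once per accumulation cycle
assert len(set(outs)) == 1     # identical outputs across microbatches
```

In this mock, caching only skips the quantization work; the outputs are identical either way, which is what we expected from the real layer as well.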
For Float8BlockScaling, our understanding from the implementation is that quantization is blockwise and does not rely on delayed amax_history state like DelayedScaling.
Because of that, we expected:
- FP8 weight caching to provide some speedup when the same frozen weights are reused across microbatches.
- Runs with and without caching to remain fully aligned under
Float8BlockScaling, since the weights are unchanged within a gradient accumulation cycle.
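Our reading of blockwise scaling is captured by this sketch (quantize_blockwise is our own illustration with int8 as an FP8 stand-in, not TE's kernels): the per-block scale is computed directly from the current weight values with no amax-history state, so requantizing unchanged weights is a pure function and should be reproducible bit for bit:

```python
# Toy blockwise quantizer (our own illustration, not TE's kernels): each
# block's scale comes only from the current weight values, with no
# amax-history state carried between steps.
def quantize_blockwise(w, block=4):
    q, scales = [], []
    for i in range(0, len(w), block):
        blk = w[i:i + block]
        s = max(abs(x) for x in blk)               # stateless per-block scale
        scales.append(s)
        q.extend(round(x / s * 127) for x in blk)  # int8 stand-in for FP8
    return q, scales

w = [0.5, -1.25, 0.75, 2.0, -0.1, 0.3, -0.7, 0.05]

first = quantize_blockwise(w)  # "first microbatch": quantize and cache
later = quantize_blockwise(w)  # requantize the unchanged weights later
assert first == later          # deterministic: caching saves cost, not bits
```

If that property also holds for the real kernels, caching should change only when quantization runs, not its result, which is why the divergence we observe surprised us.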
Questions:
- Is FP8 weight caching expected to provide a measurable speedup under Float8BlockScaling? Is there any known case where it can be neutral or even slower?
- Under Float8BlockScaling, should runs with and without FP8 weight caching be bitwise identical, or at least numerically aligned step by step?
- If differences are expected, what is the source of the difference for Float8BlockScaling, given that it does not appear to use delayed amax_history state?
The current documentation warning about FP8 weight caching causing non-bitwise-identical outputs seems understandable for delayed-scaling recipes, but we are unsure whether that warning is also intended to apply to Float8BlockScaling.
Thanks in advance for any info or advice!