
Commit 7efa25a

try github alerts
1 parent 562df9b commit 7efa25a

torchao_float8/README.md

Lines changed: 12 additions & 7 deletions

## Torch Memory Profile

> [!IMPORTANT]
> **TLDR**: The dequantization of weights in the FP8WeightsOnly config is not fused with the GEMV computation. This leads to spikes in GPU VRAM usage.

The 0-th inference iteration is profiled using a CUDA memory snapshot. The snapshots are available at the following paths: `llama_benchmark/Meta-Llama-3.1-8B_None_torch_memory_profiler.pickle`, `llama_benchmark/Meta-Llama-3.1-8B_float8dq-tensor_torch_memory_profiler.pickle`, `llama_benchmark/Meta-Llama-3.1-8B_float8wo_torch_memory_profiler.pickle`.
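
The snapshots were presumably produced with PyTorch's memory-snapshot API; a minimal sketch of how such a snapshot can be recorded is shown below (the recording options, and the reuse of one of the file names above, are assumptions rather than the benchmark's exact code). The resulting pickle can be inspected interactively at https://pytorch.org/memory_viz.

```python
import torch

# Start recording allocator events (stack traces are included by default).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run the 0-th inference iteration here (prefill + decode) ...

# Dump the recorded history to a pickle that memory_viz can load, then stop recording.
torch.cuda.memory._dump_snapshot("Meta-Llama-3.1-8B_float8wo_torch_memory_profiler.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```
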
<p><strong>Figure 3:</strong> FP8 Weights Only Static Quantization Whole Timeline</p>
</div>

> [!NOTE]
> **Initial observation**: At first glance, comparing the whole timelines in Figures 1, 2, and 3, we can notice that they all have some blocks of memory in the middle of the timeline (encircled). Also, Float8WO shows multiple memory spikes that are not present in the other two snapshots (marked by an arrow).

To dig deeper, we can zoom in on the blocks and the regions around them and view their call stacks. This reveals that most of the quantization, compilation, and inference activity actually happens in the narrow slice of memory at the top. Further, we can see different phases of inference in the memory timeline. For example, all the encircled rectangular blocks of memory have a call stack related to Torch Dynamo, Inductor, and CUDA Graph Trees, indicating that this memory was active during the compilation of the decode function.

Revisiting the GPT-Fast code in generate.py, we can see that there are several phases of inference (a rough code sketch of this flow follows the note below):

1. Quantization, except for the baseline. The model's linear layers are replaced with quantized affine layers in the subsequent steps.
2. Compilation of decode's forward call. Torch Dynamo wraps the forward in a compile wrapper.
3.1. Dummy inference pass: Prefill with the quantized model.
3.2. Dummy inference pass: Decode with frame evaluation by Torch Dynamo, compilation and lowering by Torch Inductor, and subsequent recording of the graph into a CUDAGraph.
4.1. Real inference passes: Prefill, as in 3.1.
4.2. Real inference passes: Decode with the quantized model and replay of CUDAGraph Trees.

Note: Both decode stages, 3.2 and 4.2, are also followed by cloning of the outputs for optional logits post-processing, which takes up a minor amount of memory.

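The following is a minimal, self-contained sketch of this flow, with a toy module standing in for the Llama model; the names, shapes, and the use of `mode="reduce-overhead"` are assumptions for illustration, not the benchmark's actual code.

```python
import torch
import torch.nn as nn

# Hypothetical, simplified stand-in for the flow in GPT-Fast's generate.py.
class ToyDecoder(nn.Module):
    def __init__(self, dim: int = 4096):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

model = ToyDecoder().to(device="cuda", dtype=torch.bfloat16)

# Phase 1 -- Quantization (skipped for the baseline): in GPT-Fast the linear layers are
# swapped for quantized affine layers here, e.g. via torchao's quantize_() API.

# Phase 2 -- Compilation of decode's forward: Dynamo wraps it in a compile wrapper;
# mode="reduce-overhead" is what brings CUDAGraph Trees into play.
decode = torch.compile(model, mode="reduce-overhead", fullgraph=True)

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)

# Phases 3.1/3.2 -- Dummy inference pass: the first calls trigger Dynamo frame evaluation,
# Inductor lowering, and recording of the decode graph into a CUDAGraph.
for _ in range(3):
    _ = decode(x)

# Phases 4.1/4.2 -- Real inference passes: later calls replay the recorded CUDAGraph,
# reusing the same memory pool; the output is cloned for optional post-processing.
output = decode(x).clone()
```
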
When we try to map these phases to the memory timeline from PyTorch, we can see the presence of phases 3.1, 3.2, 4.1, and 4.2. Figure 4 shows the phases using the float8 weights only config as an example. The bug of excessive memory usage during inference actually helps us see the phases more distinctly with the weights only config.

> [!NOTE]
> **Second observation**: The memory spikes occur for 3.1, 3.2, and 4.1, but not for 4.2. Phase 4.2 is decode with a CUDAGraph in use, which forces the same memory locations to be reused for graph replay; this must have avoided the memory spikes for 4.2. We can thus conclude that in any future inference pass, the prefill stage of inference will be the one responsible for CUDA OOM errors, or at least for memory spikes.

<div align="center">
<img src="figures/float8wo_marked_with_phases.png" alt="FP8 Weights Only Static Quantization's Timeline Annotated with Phases of Inference" width="800">
<p><strong>Figure 4:</strong> FP8 Weights Only Static Quantization's Timeline Annotated with Phases of Inference</p>
</div>

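The reuse of memory locations during replay, noted in the second observation, can be seen directly with PyTorch's CUDA graph API; the sketch below is a generic illustration of that mechanism, unrelated to the benchmark code.

```python
import torch

# The graph is captured once against static tensors; every replay writes into those same
# allocations instead of requesting fresh memory from the allocator.
static_x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Warm up on a side stream before capture, as recommended in the PyTorch CUDA graphs docs.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    static_y = static_x @ weight.t()
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x @ weight.t()  # allocations made here belong to the graph's pool

# Replay: copy new inputs into static_x, replay the graph, read results out of static_y.
static_x.copy_(torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16))
graph.replay()
result = static_y.clone()
```
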
> [!NOTE]
> **Third observation**: Looking at the call stack of the memory spikes, they can be attributed to the dequantization of the float8 weight tensors during the forward pass in TorchAO. In some cases, the spikes can even be a GB in size. In comparison, our other Float8 config, which uses static weight quantization and dynamic activation quantization, shows no such spikes. There are minuscule memory increases during the dynamic quantization of the activation tensors, but no memory allocations are visible for dequantization of the weights.

<div align="center">
<img src="figures/float8wo_dequantize_callstack.png" alt="Call Stack of the Spikes in Memory for FP8 Weights Only Static Quantization" width="800">
<p><strong>Figure 5:</strong> Call Stack of the Spikes in Memory for FP8 Weights Only Static Quantization</p>
</div>

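For reference, the two configurations being compared can be applied through torchao's `quantize_` API; the entry-point names and the mapping to the benchmark's config labels below are assumptions that may vary across torchao versions, and the toy module stands in for the Llama model.

```python
import torch
import torch.nn as nn
from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)

def make_model() -> nn.Module:
    # Toy stand-in for the model's linear layers.
    return nn.Sequential(nn.Linear(4096, 4096, bias=False)).to("cuda", torch.bfloat16)

model_wo = make_model()
quantize_(model_wo, float8_weight_only())                       # assumed "float8wo"

model_dq = make_model()
quantize_(model_dq, float8_dynamic_activation_float8_weight())  # assumed "float8dq-tensor"

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
y_wo = model_wo(x)  # dequantizes the FP8 weight to high precision around the matmul
y_dq = model_dq(x)  # quantizes activations on the fly and runs an FP8 GEMM
```
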
> [!NOTE]
> **Takeaway**: The FP8 Weights Only config does not handle dequantization of the weight tensor efficiently. Dequantization of the weights and the weight-activation product should be fused in the GEMV kernel, instead of being done as separate steps, to avoid such spikes in GPU VRAM.
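
As a rough illustration of the takeaway (an assumption about the mechanism, not TorchAO's kernel code), the unfused path below materializes a full high-precision copy of the weight before the GEMV, which is exactly the kind of temporary that shows up as a spike in the snapshot.

```python
import torch

w_fp8 = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
scale = torch.tensor(0.05, device="cuda", dtype=torch.bfloat16)
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)

# Unfused: materializes a full 4096x4096 bf16 copy of the weight (~32 MiB here; for the
# largest Llama-3.1-8B layers such temporaries can approach the GB-sized spikes in Figure 5).
w_dequant = w_fp8.to(torch.bfloat16) * scale
y = x @ w_dequant.t()

# Fused (what the takeaway asks for): a single kernel that loads FP8 weight tiles,
# dequantizes them in registers/shared memory, and accumulates the GEMV output directly,
# so the only new allocation is the small output vector y.
```
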
## Torch Execution Trace
