You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: torchao_float8/README.md
+12-7Lines changed: 12 additions & 7 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -30,6 +30,7 @@ As a first, and possibly only step, we use the GPT-Fast benchmark provided by To
30
30
31
31
## Torch Memory Profile
32
32
33
+
> [!IMPORTANT]
33
34
> **TLDR**: The dequantization of weights in FP8WeightsOnly config is not fused with GEMV computations. This leads to spike in GPU VRAM usage.
34
35
35
36
The 0-th inference iteration is profiled using a CUDA memory snapshot. The snapshots are available at the following paths: `llama_benchmark/Meta-Llama-3.1-8B_None_torch_memory_profiler.pickle`, `llama_benchmark/Meta-Llama-3.1-8B_float8dq-tensor_torch_memory_profiler.pickle`, `llama_benchmark/Meta-Llama-3.1-8B_float8wo_torch_memory_profiler.pickle`.
@@ -47,36 +48,40 @@ The 0-th inference iteration is profiled using a CUDA memory snapshot. The snaps
47
48
<p><strong>Figure 3:</strong> FP8 Weights Only Static Quantization Whole Timeline</p>
48
49
</div>
49
50
51
+
> [!NOTE]
50
52
> **Initial observation**: On first look, comparing the whole timelines in Figures 1, 2, and 3, we can notice that they all have some blocks of memory in the middle of the timeline (encircled). Also, Float8WO has multiple spikes in memory that are not present in the other two snapshots (marked by arrow).
51
53
52
54
For diving deeper, we can zoom-in on the blocks and the regions around them, and view their call stack. This reveals that most of the quantization, compilation, and inference activity actually happens in the narrow slice of memory at the top. Further, we can see different phases of inference in the memory timeline. For example, all the encircled rectangular blocks of memory have a call stack related to Torch Dynamo, Inductor, and CUDA Graph Trees, indicating that this memory was active during the compilation of the decode function.
53
55
54
56
Revisiting the GPT-Fast code in generate.py, we can see there are different phases of inference:
55
-
1. Quantization, except for baseline. Model's linear layers are replaced with quantized affine layers in subsequent steps.
56
-
2. Compilation of decode's forward call. Torch Dynamo wraps the forward in compile wrapper.
57
-
3.1. Dummy inference pass: Prefill with quantized model.
58
-
3.2. Dummy inference pass: Decode with frame evaluation by Torch Dynamo, compilation and lowering by Torch Inductor, and subsequent graph recording into a CUDAGraph.
59
-
4.1. Real inference passes: Prefill like 3.1.
60
-
4.2. Real inference passes: Decode with quantized model and replay of CUDAGraph Trees.
57
+
1. Quantization, except for baseline. Model's linear layers are replaced with quantized affine layers in subsequent steps.
58
+
2. Compilation of decode's forward call. Torch Dynamo wraps the forward in compile wrapper.
59
+
3.1. Dummy inference pass: Prefill with quantized model.
60
+
3.2. Dummy inference pass: Decode with frame evaluation by Torch Dynamo, compilation and lowering by Torch Inductor, and subsequent graph recording into a CUDAGraph.
61
+
4.1. Real inference passes: Prefill like 3.1.
62
+
4.2. Real inference passes: Decode with quantized model and replay of CUDAGraph Trees.
61
63
62
-
Note: Both decode are also followed by cloning of outputs for optional post-processing which takes up minor amount of memory.
64
+
Note: Both decode stages, 3.2 and 4.2, are also followed by cloning of outputs for optional logits post-processing which takes up minor amount of memory.
63
65
64
66
When we try to map these phases to the memory timeline from PyTorch, we can see presence of 3.1, 3.2, 4.1, and 4.2. Figure 4 shows the phases by using the float8 weights only config as an example. The bug of excessive memory usage during inference actually helps us to see the phases more distinctly with the weights only config.
65
67
68
+
> [!NOTE]
66
69
> **Second Observation**: We can see that the spikes in memory happen for 3.1, 3.2, and 4.1, but not for 4.2. 4.2 is decode with CUDAGraph being used, which forces the same memory locations to be used for graph replay. This must have avoided the memory spikes for 4.2. We can thus conclude that in any future inference pass, prefill-stage of inference will the one responsible for CUDA OOM bugs, or at least spikes in the memory.
67
70
68
71
<divalign="center">
69
72
<imgsrc="figures/float8wo_marked_with_phases.png"alt="FP8 Weights Only Static Quantization's Timeline Annotated with Phases of Inference"width="800">
70
73
<p><strong>Figure 4:</strong> FP8 Weights Only Static Quantization's Timeline Annotated with Phases of Inference</p>
71
74
</div>
72
75
76
+
> [!NOTE]
73
77
> **Third Observation**: Looking at the call stack of the spikes in memory, the spikes can be attributed to the dequantization of the float8 weight tensors during the forward pass in TorchAO. In some cases, the spikes can even be a GB in size. Comparing this with our other Float8 config which uses static weights quantization and dynamic activations quantization, the spikes are absent. There are miniscule memory increases during the dynamic quantization of the activation tensors, but there aren't any memory allocations visible for dequantization of the weights.
74
78
75
79
<divalign="center">
76
80
<imgsrc="figures/float8wo_dequantize_callstack.png"alt="Call Stack of the Spikes in Memory for FP8 Weights Only Static Quantization"width="800">
77
81
<p><strong>Figure 5:</strong> Call Stack of the Spikes in Memory for FP8 Weights Only Static Quantization</p>
78
82
</div>
79
83
84
+
> [!NOTE]
80
85
> **Takeaway**: The FP8 Weights Only Config is not doing dequantization of weights tensor properly. Dequantization of the weights and the computation of product of weights and activations should be fused in the GEMV kernel instead of being done separately to avoid such memory spikes in the GPU VRAM.
0 commit comments