Skip to content

Commit 415baea

Browse files
committed
Add capture cudagraph tqdm
1 parent 433af6f commit 415baea

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

lightllm/common/basemodel/cuda_graph.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import copy
44
import bisect
55
from typing import Optional
6+
from tqdm import tqdm
67
from lightllm.utils.log_utils import init_logger
78
from lightllm.utils.envs_utils import get_env_start_args
89
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -191,7 +192,12 @@ def warmup(self, model):
191192
model: TpPartBaseModel = model
192193

193194
# decode cuda graph init
194-
for batch_size in self.cuda_graph_batch_sizes[::-1]:
195+
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
196+
for batch_size in progress_bar:
197+
# Get available memory info
198+
avail_mem, total_mem = torch.cuda.mem_get_info()
199+
avail_mem_gb = avail_mem / (1024**3)
200+
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
195201
seq_len = 2
196202
total_token_num = batch_size * seq_len
197203
max_len_in_batch = self.graph_max_len_in_batch
@@ -246,7 +252,12 @@ def warmup_overlap(self, model):
246252

247253
model: TpPartBaseModel = model
248254

249-
for batch_size in self.cuda_graph_batch_sizes[::-1]:
255+
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
256+
for batch_size in progress_bar:
257+
# Get available memory info
258+
avail_mem, total_mem = torch.cuda.mem_get_info()
259+
avail_mem_gb = avail_mem / (1024**3)
260+
progress_bar.set_description(f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
250261
decode_batches = []
251262
for micro_batch_index in [0, 1]:
252263
# dummy decoding, capture the cudagraph

0 commit comments

Comments (0)