Skip to content

Commit b9d33dc

Browse files
committed
[Enhance] enhance logging (mostly token stats)
1. use reduced token stats for exp_tgs; 2. rename some variables for clarity (drop the reduced_ prefix in favor of total_); 3. drop maxvio stats from loss_log, as they are already covered by internal metrics.
1 parent 0f9f751 commit b9d33dc

3 files changed

Lines changed: 40 additions & 41 deletions

File tree

xtuner/v1/engine/train_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class LossLog(TypedDict):
5353
class OtherLog(TypedDict):
5454
__pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc]
5555
maxvio: NotRequired[float]
56-
consumed_tokens: float
56+
consumed_tokens: int
5757
extra_info: ModelForwardExtraLogInfo
5858
efficient_attn_ratio: float
5959

@@ -252,7 +252,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
252252
step_llm_loss = torch.tensor(0.0, device=DEVICE)
253253
step_balancing_loss: torch.Tensor | None = None
254254
step_z_loss: torch.Tensor | None = None
255-
step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
255+
step_consumed_tokens = torch.tensor(0, device=DEVICE)
256256

257257
if self._count == 0:
258258
logger.info(f"grad_accumulation_steps: {iters_per_step}")
@@ -346,7 +346,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
346346
reduced_z_loss = step_z_loss
347347
dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
348348
loss_log["reduced_z_loss"] = reduced_z_loss.item()
349-
other_log["consumed_tokens"] = step_consumed_tokens.item()
349+
other_log["consumed_tokens"] = cast(int, step_consumed_tokens.item())
350350
other_log["extra_info"] = train_engine_extra_info
351351
other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
352352
return loss_log, other_log

xtuner/v1/engine/vision_compose_train_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
183183
step_llm_loss = torch.tensor(0.0, device=DEVICE)
184184
step_balancing_loss: torch.Tensor | None = None
185185
step_z_loss: torch.Tensor | None = None
186-
step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
186+
step_consumed_tokens = torch.tensor(0, device=DEVICE)
187187
efficient_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
188188
total_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
189189

@@ -257,7 +257,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
257257
reduced_z_loss = step_z_loss
258258
dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
259259
loss_log["reduced_z_loss"] = reduced_z_loss.item()
260-
other_log["consumed_tokens"] = step_consumed_tokens.item()
260+
other_log["consumed_tokens"] = cast(int, step_consumed_tokens.item())
261261
other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment]
262262
other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
263263
return loss_log, other_log

xtuner/v1/train/trainer.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,9 @@ def __init__(
514514
self._debug = debug
515515
self._seed = seed
516516

517-
self._reduced_consumed_tokens = 0
517+
self._total_consumed_tokens = 0
518518
self._exp_consumed_tokens = 0
519-
self._reduced_consumed_samples = 0
519+
self._total_consumed_samples = 0
520520

521521
self._train_time = 0
522522
self._train_time_offset = 0
@@ -762,26 +762,24 @@ def fit(self):
762762
extra_info_dict = extra_info_updated.get()
763763
loss_log.update(extra_info_dict)
764764

765-
if "maxvio" in other_log:
766-
loss_log["maxvio"] = other_log["maxvio"]
767765
loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
768766

769767
internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
770768

771769
self._cur_step += 1
772770

773771
reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
774-
self._reduced_consumed_tokens += reduced_step_consumed_tokens
775-
776-
self._exp_consumed_tokens += step_consumed_tokens
772+
self._total_consumed_tokens += reduced_step_consumed_tokens
773+
self._exp_consumed_tokens += reduced_step_consumed_tokens
777774
self._train_time = time_after_train_step - train_begin
778775

779776
# TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
780777
self._log_step(
781778
loss_log=loss_log,
782-
step_consumed_tokens=step_consumed_tokens,
779+
local_step_consumed_tokens=step_consumed_tokens,
780+
step_consumed_tokens=reduced_step_consumed_tokens,
783781
exp_consumed_tokens=self._exp_consumed_tokens,
784-
reduced_consumed_tokens=self._reduced_consumed_tokens,
782+
total_consumed_tokens=self._total_consumed_tokens,
785783
data_time=data_time,
786784
step_time=step_time,
787785
train_time=self._train_time,
@@ -1137,8 +1135,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
11371135
{
11381136
"cur_step": self.cur_step,
11391137
"cur_epoch": self._cur_epoch,
1140-
"reduced_consumed_samples": self._reduced_consumed_samples,
1141-
"reduced_consumed_tokens": self._reduced_consumed_tokens,
1138+
"total_consumed_samples": self._total_consumed_samples,
1139+
"total_consumed_tokens": self._total_consumed_tokens,
11421140
"train_time_offset": self._train_time + self._train_time_offset,
11431141
}
11441142
)
@@ -1150,8 +1148,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
11501148
ckp_list.append(str(checkpoint_path))
11511149
current_exp.cur_step = self.cur_step
11521150
current_exp.cur_epoch = self._cur_epoch
1153-
current_exp.consumed_samples = int(self._reduced_consumed_samples)
1154-
current_exp.consumed_tokens = int(self._reduced_consumed_tokens)
1151+
current_exp.consumed_samples = int(self._total_consumed_samples)
1152+
current_exp.consumed_tokens = int(self._total_consumed_tokens)
11551153
current_exp.history[-1]["end"] = self.cur_step
11561154

11571155
# Delete checkpoints and update meta's checkpoint_list
@@ -1188,7 +1186,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
11881186

11891187
def _save_dataloader(self, dataloader_path: Path | str):
11901188
if self.rank == 0:
1191-
dataloader_state = self._dataloader.get_state_dict(self._reduced_consumed_samples)
1189+
dataloader_state = self._dataloader.get_state_dict(self._total_consumed_samples)
11921190
torch.save(dataloader_state, dataloader_path)
11931191

11941192
@property
@@ -1243,7 +1241,7 @@ def _data_iter(self):
12431241
data_iter = iter(self._dataloader)
12441242
data = next(data_iter)
12451243

1246-
self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
1244+
self._total_consumed_samples += self._reduce_number_across_rank(len(data))
12471245
yield data
12481246

12491247
def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1413,10 +1411,11 @@ def _maybe_profiling(self):
14131411

14141412
def _log_step(
14151413
self,
1416-
loss_log: dict,
1414+
loss_log: LossLog,
1415+
local_step_consumed_tokens: int,
14171416
step_consumed_tokens: int,
14181417
exp_consumed_tokens: int,
1419-
reduced_consumed_tokens: int,
1418+
total_consumed_tokens: int,
14201419
data_time: float,
14211420
step_time: float,
14221421
train_time: float,
@@ -1426,20 +1425,20 @@ def _log_step(
14261425
):
14271426
"""Log the training step information."""
14281427
e2e_train_time = train_time + train_time_offset
1429-
tgs = step_consumed_tokens / step_time
1430-
rank_consumed_tokens = reduced_consumed_tokens / self.world_size
1431-
e2e_tgs = rank_consumed_tokens / e2e_train_time
1432-
exp_tgs = exp_consumed_tokens / train_time
1428+
total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size
1429+
exp_consumed_tokens_per_rank = exp_consumed_tokens / self.world_size
1430+
1431+
tgs = local_step_consumed_tokens / step_time
1432+
e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time
1433+
exp_tgs = exp_consumed_tokens_per_rank / train_time
14331434
lr = self._lr_scheduler.get_last_lr()[0]
14341435

14351436
remaining_steps = self.total_step - self.cur_step
1436-
avg_tokens_per_step = rank_consumed_tokens / self.cur_step
1437+
avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step
14371438
remaining_tokens = remaining_steps * avg_tokens_per_step
14381439
eta_seconds = remaining_tokens / (tgs + 1e-12)
14391440
eta_hms = str(timedelta(seconds=int(eta_seconds)))
14401441

1441-
est_global_batch_tokens = self.data_mesh["dp"].size() * step_consumed_tokens
1442-
14431442
loss_log_list = [f"{k}: {v:.8f}" for k, v in loss_log.items()]
14441443
loss_log_str = ", ".join(loss_log_list)
14451444

@@ -1453,16 +1452,16 @@ def _log_step(
14531452
self.logger.info(
14541453
f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
14551454
f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
1456-
f"text_tokens: {step_consumed_tokens} "
1457-
f"reduced_consumed_tokens: {reduced_consumed_tokens} "
1455+
f"text_tokens: {local_step_consumed_tokens} "
1456+
f"step_consumed_tokens: {step_consumed_tokens} "
1457+
f"total_consumed_tokens: {total_consumed_tokens} "
14581458
f"{loss_log_str} "
14591459
f"grad_norm: {grad_norm:.8f} "
14601460
f"max_memory: {max_memory / (1024**3):.2f} GB "
14611461
f"reserved_memory: {reserved_memory / (1024**3):.2f} GB "
14621462
f"tgs: {tgs:.1f} "
14631463
f"exp_tgs: {exp_tgs: .1f} "
14641464
f"e2e_tgs: {e2e_tgs:.1f} "
1465-
f"est_global_batch_tokens: {est_global_batch_tokens} "
14661465
f"eta: {eta_hms} "
14671466
)
14681467

@@ -1472,9 +1471,9 @@ def _log_step(
14721471
"time/step_time": round(step_time, 4),
14731472
"time/train_time": round(train_time, 4),
14741473
"time/eta_seconds": round(eta_seconds, 1),
1475-
"runtime_info/text_tokens": step_consumed_tokens,
1476-
"runtime_info/est_global_batch_tokens": est_global_batch_tokens,
1477-
"runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
1474+
"runtime_info/text_tokens": local_step_consumed_tokens,
1475+
"runtime_info/step_consumed_tokens": step_consumed_tokens,
1476+
"runtime_info/total_consumed_tokens": total_consumed_tokens,
14781477
"runtime_info/tgs": tgs,
14791478
"runtime_info/exp_tgs": exp_tgs,
14801479
"runtime_info/e2e_tgs": e2e_tgs,
@@ -1679,13 +1678,13 @@ def _load_checkpoint(self):
16791678
self._cur_epoch = train_state["cur_epoch"]
16801679

16811680
if load_checkpoint_cfg.load_dataset:
1682-
self._reduced_consumed_tokens = train_state.get("reduced_consumed_tokens", 0) # default 0 for BC
1681+
self._total_consumed_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC
16831682
self._train_time_offset = train_state["train_time_offset"]
1684-
# _reduced_consumed_samples 会影响 save dcp时 dataloader.get_state_dict的状态。
1685-
# 1) 如果加载 dataset,应该恢复_reduced_consumed_samples为checkpoint中的值
1686-
# 2) 如果不加载 dataset,应该保持_reduced_consumed_samples为初始值0,否则如果加载上旧dataloader的reduced_consumed_samples
1687-
# 会导致存储新dataloader时 reduced_consumed_samples 是不正确的值。
1688-
self._reduced_consumed_samples = train_state.get("reduced_consumed_samples", 0) # default 0 for BC
1683+
# _total_consumed_samples 会影响 save dcp时 dataloader.get_state_dict的状态。
1684+
# 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值
1685+
# 2) 如果不加载 dataset,应该保持_total_consumed_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples
1686+
# 会导致存储新dataloader时 total_consumed_samples 是不正确的值。
1687+
self._total_consumed_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC
16891688

16901689
dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
16911690
self._resume_dataloader(dataloader_path)

0 commit comments

Comments (0)