@@ -514,9 +514,9 @@ def __init__(
         self._debug = debug
         self._seed = seed

-        self._reduced_consumed_tokens = 0
+        self._total_consumed_tokens = 0
         self._exp_consumed_tokens = 0
-        self._reduced_consumed_samples = 0
+        self._total_consumed_samples = 0

         self._train_time = 0
         self._train_time_offset = 0
@@ -762,26 +762,24 @@ def fit(self):
             extra_info_dict = extra_info_updated.get()
             loss_log.update(extra_info_dict)

-            if "maxvio" in other_log:
-                loss_log["maxvio"] = other_log["maxvio"]
             loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]

             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)

             self._cur_step += 1

             reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
-            self._reduced_consumed_tokens += reduced_step_consumed_tokens
-
-            self._exp_consumed_tokens += step_consumed_tokens
+            self._total_consumed_tokens += reduced_step_consumed_tokens
+            self._exp_consumed_tokens += reduced_step_consumed_tokens
             self._train_time = time_after_train_step - train_begin

             # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                step_consumed_tokens=step_consumed_tokens,
+                local_step_consumed_tokens=step_consumed_tokens,
+                step_consumed_tokens=reduced_step_consumed_tokens,
                 exp_consumed_tokens=self._exp_consumed_tokens,
-                reduced_consumed_tokens=self._reduced_consumed_tokens,
+                total_consumed_tokens=self._total_consumed_tokens,
                 data_time=data_time,
                 step_time=step_time,
                 train_time=self._train_time,
@@ -1137,8 +1135,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
                 {
                     "cur_step": self.cur_step,
                     "cur_epoch": self._cur_epoch,
-                    "reduced_consumed_samples": self._reduced_consumed_samples,
-                    "reduced_consumed_tokens": self._reduced_consumed_tokens,
+                    "total_consumed_samples": self._total_consumed_samples,
+                    "total_consumed_tokens": self._total_consumed_tokens,
                     "train_time_offset": self._train_time + self._train_time_offset,
                 }
             )
@@ -1150,8 +1148,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
             ckp_list.append(str(checkpoint_path))
             current_exp.cur_step = self.cur_step
             current_exp.cur_epoch = self._cur_epoch
-            current_exp.consumed_samples = int(self._reduced_consumed_samples)
-            current_exp.consumed_tokens = int(self._reduced_consumed_tokens)
+            current_exp.consumed_samples = int(self._total_consumed_samples)
+            current_exp.consumed_tokens = int(self._total_consumed_tokens)
             current_exp.history[-1]["end"] = self.cur_step

             # Delete checkpoints and update meta's checkpoint_list
@@ -1188,7 +1186,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:

     def _save_dataloader(self, dataloader_path: Path | str):
         if self.rank == 0:
-            dataloader_state = self._dataloader.get_state_dict(self._reduced_consumed_samples)
+            dataloader_state = self._dataloader.get_state_dict(self._total_consumed_samples)
             torch.save(dataloader_state, dataloader_path)

     @property
@@ -1243,7 +1241,7 @@ def _data_iter(self):
            data_iter = iter(self._dataloader)
            data = next(data_iter)

-            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
+            self._total_consumed_samples += self._reduce_number_across_rank(len(data))
            yield data

     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1413,10 +1411,11 @@ def _maybe_profiling(self):

     def _log_step(
         self,
-        loss_log: dict,
+        loss_log: LossLog,
+        local_step_consumed_tokens: int,
         step_consumed_tokens: int,
         exp_consumed_tokens: int,
-        reduced_consumed_tokens: int,
+        total_consumed_tokens: int,
         data_time: float,
         step_time: float,
         train_time: float,
@@ -1426,20 +1425,20 @@ def _log_step(
     ):
         """Log the training step information."""
         e2e_train_time = train_time + train_time_offset
-        tgs = step_consumed_tokens / step_time
-        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
-        e2e_tgs = rank_consumed_tokens / e2e_train_time
-        exp_tgs = exp_consumed_tokens / train_time
+        total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size
+        exp_consumed_tokens_per_rank = exp_consumed_tokens / self.world_size
+
+        tgs = local_step_consumed_tokens / step_time
+        e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time
+        exp_tgs = exp_consumed_tokens_per_rank / train_time
         lr = self._lr_scheduler.get_last_lr()[0]

         remaining_steps = self.total_step - self.cur_step
-        avg_tokens_per_step = rank_consumed_tokens / self.cur_step
+        avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step
         remaining_tokens = remaining_steps * avg_tokens_per_step
         eta_seconds = remaining_tokens / (tgs + 1e-12)
         eta_hms = str(timedelta(seconds=int(eta_seconds)))

-        est_global_batch_tokens = self.data_mesh["dp"].size() * step_consumed_tokens
-
         loss_log_list = [f"{k}: {v:.8f}" for k, v in loss_log.items()]
         loss_log_str = ", ".join(loss_log_list)

@@ -1453,16 +1452,16 @@ def _log_step(
         self.logger.info(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
-            f"text_tokens: {step_consumed_tokens} "
-            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
+            f"text_tokens: {local_step_consumed_tokens} "
+            f"step_consumed_tokens: {step_consumed_tokens} "
+            f"total_consumed_tokens: {total_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024 ** 3):.2f} GB "
             f"reserved_memory: {reserved_memory / (1024 ** 3):.2f} GB "
             f"tgs: {tgs:.1f} "
             f"exp_tgs: {exp_tgs: .1f} "
             f"e2e_tgs: {e2e_tgs:.1f} "
-            f"est_global_batch_tokens: {est_global_batch_tokens} "
             f"eta: {eta_hms} "
         )

@@ -1472,9 +1471,9 @@ def _log_step(
             "time/step_time": round(step_time, 4),
             "time/train_time": round(train_time, 4),
             "time/eta_seconds": round(eta_seconds, 1),
-            "runtime_info/text_tokens": step_consumed_tokens,
-            "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-            "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
+            "runtime_info/text_tokens": local_step_consumed_tokens,
+            "runtime_info/step_consumed_tokens": step_consumed_tokens,
+            "runtime_info/total_consumed_tokens": total_consumed_tokens,
             "runtime_info/tgs": tgs,
             "runtime_info/exp_tgs": exp_tgs,
             "runtime_info/e2e_tgs": e2e_tgs,
@@ -1679,13 +1678,13 @@ def _load_checkpoint(self):
         self._cur_epoch = train_state["cur_epoch"]

         if load_checkpoint_cfg.load_dataset:
-            self._reduced_consumed_tokens = train_state.get("reduced_consumed_tokens", 0)  # default 0 for BC
+            self._total_consumed_tokens = train_state.get("total_consumed_tokens", 0)  # default 0 for BC
             self._train_time_offset = train_state["train_time_offset"]
-            # _reduced_consumed_samples affects the state returned by dataloader.get_state_dict when saving dcp.
-            # 1) If the dataset is loaded, _reduced_consumed_samples should be restored to the value stored in the checkpoint.
-            # 2) If the dataset is not loaded, _reduced_consumed_samples should stay at its initial value 0; otherwise, carrying over
-            # the old dataloader's reduced_consumed_samples would make the value saved for the new dataloader incorrect.
-            self._reduced_consumed_samples = train_state.get("reduced_consumed_samples", 0)  # default 0 for BC
+            # _total_consumed_samples affects the state returned by dataloader.get_state_dict when saving dcp.
+            # 1) If the dataset is loaded, _total_consumed_samples should be restored to the value stored in the checkpoint.
+            # 2) If the dataset is not loaded, _total_consumed_samples should stay at its initial value 0; otherwise, carrying over
+            # the old dataloader's total_consumed_samples would make the value saved for the new dataloader incorrect.
+            self._total_consumed_samples = train_state.get("total_consumed_samples", 0)  # default 0 for BC

         dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
         self._resume_dataloader(dataloader_path)
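Note on the renaming: after this change, `text_tokens` and `tgs` are per-rank values for the current step, `step_consumed_tokens` is the cross-rank sum for the step, and `total_consumed_tokens` / `total_consumed_samples` are the running cross-rank totals that get checkpointed. The cross-rank sums come from `_reduce_number_across_rank`, whose body is not part of this diff. A minimal sketch of how such a helper is commonly written (an assumption, not the repository's actual implementation; it presumes `torch.distributed` is initialized and uses an illustrative dtype/device):

```python
import torch
import torch.distributed as dist


def reduce_number_across_rank(value: int | float) -> float:
    """Sum a per-rank scalar across all ranks (illustrative sketch only)."""
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to reduce.
        return float(value)
    # Device placement depends on the backend (e.g. CUDA tensors for NCCL).
    tensor = torch.tensor([float(value)], dtype=torch.float64)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item()
```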