diff --git a/application/backend/src/services/training_service.py b/application/backend/src/services/training_service.py
index 0fe5b3bcc5..fb4e6f029d 100644
--- a/application/backend/src/services/training_service.py
+++ b/application/backend/src/services/training_service.py
@@ -230,6 +230,8 @@ def _train_model(
         if synchronization_parameters.cancel_training_event.is_set():
             return None
 
+        synchronization_parameters.message = "exporting model"
+
         export_path = engine.export(
             model=anomalib_model,
             export_type=export_format,
diff --git a/application/backend/src/utils/callbacks.py b/application/backend/src/utils/callbacks.py
index e3199f967b..bf712ad65a 100644
--- a/application/backend/src/utils/callbacks.py
+++ b/application/backend/src/utils/callbacks.py
@@ -108,6 +108,21 @@ def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, bat
         del pl_module, batch, batch_idx  # unused
         self._check_cancel_training(trainer)
 
+    def on_train_batch_end(
+        self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
+    ) -> None:
+        """Called when a training batch ends. Sends granular progress updates within each epoch."""
+        del pl_module, outputs, batch  # unused
+        if trainer.state.stage is not None and trainer.max_epochs is not None and trainer.max_epochs > 0:
+            total_batches = trainer.num_training_batches
+            # num_training_batches is float("inf") for iterable datasets without __len__;
+            # skip granular updates in that case instead of dividing by inf (progress 0.0 spam).
+            if total_batches and total_batches != float("inf"):
+                epoch_progress = trainer.current_epoch / trainer.max_epochs
+                batch_progress = (batch_idx + 1) / total_batches / trainer.max_epochs
+                progress = epoch_progress + batch_progress
+                self._send_progress(progress, trainer.state.stage.value)
+        self._check_cancel_training(trainer)
+
     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         """Called when a training epoch ends."""
         del pl_module  # unused