Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions application/backend/src/services/training_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def _train_model(
if synchronization_parameters.cancel_training_event.is_set():
return None

synchronization_parameters.message = "exporting model"

export_path = engine.export(
model=anomalib_model,
export_type=export_format,
Expand Down
14 changes: 14 additions & 0 deletions application/backend/src/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, bat
del pl_module, batch, batch_idx # unused
self._check_cancel_training(trainer)

def on_train_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
) -> None:
"""Called when a training batch ends. Sends granular progress updates within each epoch."""
del pl_module, outputs, batch # unused
if trainer.state.stage is not None and trainer.max_epochs is not None and trainer.max_epochs > 0:
total_batches = trainer.num_training_batches
if total_batches and total_batches > 0:
epoch_progress = trainer.current_epoch / trainer.max_epochs
batch_progress = (batch_idx + 1) / total_batches / trainer.max_epochs
progress = epoch_progress + batch_progress
self._send_progress(progress, trainer.state.stage.value)
self._check_cancel_training(trainer)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when a training epoch ends."""
del pl_module # unused
Expand Down
Loading