Skip to content

Commit d3ab907

Browse files
🔧 chore(inspect): add early stopping to training job (#3217)
* Add early stopping Signed-off-by: Ashwin Vaidya <[email protected]> * Fix test Signed-off-by: Ashwin Vaidya <[email protected]> --------- Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 9b27ee2 commit d3ab907

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

application/backend/src/services/training_service.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from anomalib.deploy import ExportType
1111
from anomalib.engine import Engine
1212
from anomalib.loggers import AnomalibTensorBoardLogger
13+
from anomalib.metrics import AUROC, F1Score
14+
from anomalib.metrics.evaluator import Evaluator
1315
from anomalib.models import get_model
16+
from lightning.pytorch.callbacks import EarlyStopping
1417
from loguru import logger
1518

1619
from pydantic_models import Job, JobStatus, JobType, Model
@@ -203,7 +206,18 @@ def _train_model(
203206
)
204207

205208
# Initialize anomalib model and engine
206-
anomalib_model = get_model(model=model.name)
209+
anomalib_model = get_model(
210+
model=model.name,
211+
evaluator=Evaluator(
212+
val_metrics=[AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_", strict=False)],
213+
test_metrics=[
214+
AUROC(fields=["pred_score", "gt_label"], prefix="image_"),
215+
F1Score(fields=["pred_label", "gt_label"], prefix="image_"),
216+
AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_", strict=False),
217+
F1Score(fields=["pred_mask", "gt_mask"], prefix="pixel_", strict=False),
218+
],
219+
),
220+
)
207221

208222
trackio = TrackioLogger(project=str(model.project_id), name=model.name)
209223
tensorboard = AnomalibTensorBoardLogger(save_dir=global_log_config.tensorboard_log_path, name=name)
@@ -212,7 +226,10 @@ def _train_model(
212226
logger=[trackio, tensorboard],
213227
devices=[0], # Only single GPU training is supported for now
214228
max_epochs=max_epochs,
215-
callbacks=[GetiInspectProgressCallback(synchronization_parameters)],
229+
callbacks=[
230+
GetiInspectProgressCallback(synchronization_parameters),
231+
EarlyStopping(monitor="pixel_AUROC", mode="max", patience=5),
232+
],
216233
accelerator=training_device,
217234
)
218235

application/backend/tests/unit/services/test_training_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33
import asyncio
44
from typing import Any
5-
from unittest.mock import AsyncMock, MagicMock, patch
5+
from unittest.mock import ANY, AsyncMock, MagicMock, patch
66
from uuid import uuid4
77

88
import pytest
@@ -274,7 +274,7 @@ def test_train_model_success(
274274

275275
# Verify all components were called correctly
276276
fxt_mock_anomalib_components["folder_class"].assert_called_once()
277-
fxt_mock_anomalib_components["get_model"].assert_called_once_with(model=fxt_model.name)
277+
fxt_mock_anomalib_components["get_model"].assert_called_once_with(model=fxt_model.name, evaluator=ANY)
278278

279279
# Verify Engine was called with expected parameters
280280
fxt_mock_anomalib_components["engine_class"].assert_called_once()

0 commit comments

Comments
 (0)