diff --git a/application/backend/src/services/model_service.py b/application/backend/src/services/model_service.py index 09f6b62408..8617461b5f 100644 --- a/application/backend/src/services/model_service.py +++ b/application/backend/src/services/model_service.py @@ -25,7 +25,7 @@ from pydantic_models import Model, ModelList, PredictionLabel, PredictionResponse from pydantic_models.base import Pagination from pydantic_models.model import ExportParameters -from repositories import ModelRepository +from repositories import JobRepository, ModelRepository from repositories.binary_repo import ModelBinaryRepository, ModelExportBinaryRepository from services import ResourceNotFoundError from services.dataset_snapshot_service import DatasetSnapshotService @@ -112,10 +112,16 @@ async def delete_model(cls, project_id: UUID, model_id: UUID) -> None: ds_snapshot_id = model.dataset_snapshot_id await DatasetSnapshotService.delete_snapshot_if_unused(snapshot_id=ds_snapshot_id, project_id=project_id) + train_job_id = model.train_job_id + async with get_async_db_session_ctx() as session: repo = ModelRepository(session, project_id=project_id) await repo.delete_by_id(model_id) + if train_job_id: + job_repo = JobRepository(session) + await job_repo.delete_by_id(train_job_id) + @classmethod async def delete_project_models_db(cls, session: AsyncSession, project_id: UUID, commit: bool = False) -> None: """Delete all models associated with a project from the database.""" diff --git a/application/backend/tests/unit/services/test_model_service.py b/application/backend/tests/unit/services/test_model_service.py index 2a3aae08e5..ffc07ff1f1 100644 --- a/application/backend/tests/unit/services/test_model_service.py +++ b/application/backend/tests/unit/services/test_model_service.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 import numpy as np import openvino.properties.hint as ov_hints @@ -151,6 +152,31 @@ def test_delete_model(self, fxt_model_service, fxt_model_repository, fxt_model, ) mock_binary_repo.delete_model_folder.assert_called_once() + def test_delete_model_also_deletes_training_job( + self, fxt_model_service, fxt_model_repository, fxt_model, fxt_project + ): + """Test that deleting a model also deletes its associated training job.""" + fxt_model_repository.delete_by_id.return_value = None + fxt_model_repository.get_by_id.return_value = fxt_model + fxt_model.train_job_id = uuid4() + + with ( + patch("services.model_service.ModelRepository") as mock_repo_class, + patch("services.model_service.JobRepository") as mock_job_repo_class, + patch("services.model_service.DatasetSnapshotService") as mock_snapshot_service, + patch("services.model_service.ModelBinaryRepository") as mock_binary_repo_class, + ): + mock_repo_class.return_value = fxt_model_repository + mock_job_repo = MagicMock() + mock_job_repo.delete_by_id = AsyncMock() + mock_job_repo_class.return_value = mock_job_repo + mock_binary_repo_class.return_value.delete_model_folder = AsyncMock() + mock_snapshot_service.delete_snapshot_if_unused = AsyncMock() + + asyncio.run(fxt_model_service.delete_model(fxt_project.id, fxt_model.id)) + + mock_job_repo.delete_by_id.assert_called_once_with(fxt_model.train_job_id) + def test_load_inference_model_success(self, fxt_model_service, fxt_model, fxt_openvino_inferencer): """Test loading inference model successfully.""" with patch("services.model_service.ModelBinaryRepository") as mock_bin_repo_class: