Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion application/backend/src/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pydantic_models import Model, ModelList, PredictionLabel, PredictionResponse
from pydantic_models.base import Pagination
from pydantic_models.model import ExportParameters
from repositories import ModelRepository
from repositories import JobRepository, ModelRepository
from repositories.binary_repo import ModelBinaryRepository, ModelExportBinaryRepository
from services import ResourceNotFoundError
from services.dataset_snapshot_service import DatasetSnapshotService
Expand Down Expand Up @@ -112,10 +112,16 @@ async def delete_model(cls, project_id: UUID, model_id: UUID) -> None:
ds_snapshot_id = model.dataset_snapshot_id
await DatasetSnapshotService.delete_snapshot_if_unused(snapshot_id=ds_snapshot_id, project_id=project_id)

train_job_id = model.train_job_id

async with get_async_db_session_ctx() as session:
repo = ModelRepository(session, project_id=project_id)
await repo.delete_by_id(model_id)

if train_job_id:
job_repo = JobRepository(session)
await job_repo.delete_by_id(train_job_id)

@classmethod
async def delete_project_models_db(cls, session: AsyncSession, project_id: UUID, commit: bool = False) -> None:
"""Delete all models associated with a project from the database."""
Expand Down
26 changes: 26 additions & 0 deletions application/backend/tests/unit/services/test_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import numpy as np
import openvino.properties.hint as ov_hints
Expand Down Expand Up @@ -151,6 +152,31 @@ def test_delete_model(self, fxt_model_service, fxt_model_repository, fxt_model,
)
mock_binary_repo.delete_model_folder.assert_called_once()

def test_delete_model_also_deletes_training_job(
self, fxt_model_service, fxt_model_repository, fxt_model, fxt_project
):
"""Test that deleting a model also deletes its associated training job."""
fxt_model_repository.delete_by_id.return_value = None
fxt_model_repository.get_by_id.return_value = fxt_model
fxt_model.train_job_id = uuid4()

with (
patch("services.model_service.ModelRepository") as mock_repo_class,
patch("services.model_service.JobRepository") as mock_job_repo_class,
patch("services.model_service.DatasetSnapshotService") as mock_snapshot_service,
patch("services.model_service.ModelBinaryRepository") as mock_binary_repo_class,
):
mock_repo_class.return_value = fxt_model_repository
mock_job_repo = MagicMock()
mock_job_repo.delete_by_id = AsyncMock()
mock_job_repo_class.return_value = mock_job_repo
mock_binary_repo_class.return_value.delete_model_folder = AsyncMock()
mock_snapshot_service.delete_snapshot_if_unused = AsyncMock()

asyncio.run(fxt_model_service.delete_model(fxt_project.id, fxt_model.id))

mock_job_repo.delete_by_id.assert_called_once_with(fxt_model.train_job_id)

def test_load_inference_model_success(self, fxt_model_service, fxt_model, fxt_openvino_inferencer):
"""Test loading inference model successfully."""
with patch("services.model_service.ModelBinaryRepository") as mock_bin_repo_class:
Expand Down