From 14a449f5eaa6c2cfc834b48d6d230343ceaabf8b Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 14 Apr 2026 07:43:32 -0700 Subject: [PATCH 1/6] feat(jobs): schedule periodic stale-job reconcile + NATS consumer snapshots MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two Celery Beat periodic tasks so operators can see ML job health without waiting for a job to finish: 1. check_stale_jobs_task (every 15 min) — thin wrapper around the existing check_stale_jobs() that reconciles jobs stuck in running states past FAILED_CUTOFF_HOURS. Previously the function existed but was only reachable via the update_stale_jobs management command, so nothing ran it automatically. 2. log_running_async_job_stats (every 5 min) — for each incomplete async_api job, opens a TaskQueueManager with that job's logger and logs a delivered/ack_floor/num_pending/num_ack_pending/num_redelivered snapshot of the NATS consumer. Read-only; no status changes. Builds on the lifecycle-logging landed in #1222 so long-running jobs now get mid-flight visibility, not just create + cleanup snapshots. Both registered via migration 0020 so existing deployments pick them up on the next migrate without manual Beat configuration. Co-Authored-By: Claude --- ...0020_schedule_job_monitoring_beat_tasks.py | 55 +++++++++++ ami/jobs/tasks.py | 63 +++++++++++++ ami/jobs/tests/test_periodic_beat_tasks.py | 92 +++++++++++++++++++ ami/ml/orchestration/nats_queue.py | 24 +++-- ami/ml/orchestration/tests/test_nats_queue.py | 38 ++++++++ 5 files changed, 266 insertions(+), 6 deletions(-) create mode 100644 ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py create mode 100644 ami/jobs/tests/test_periodic_beat_tasks.py diff --git a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py new file mode 100644 index 000000000..c6996d8e9 --- /dev/null +++ b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py @@ -0,0 +1,55 @@ +from django.db import migrations + + +def create_periodic_tasks(apps, schema_editor): + from django_celery_beat.models import CrontabSchedule, PeriodicTask + + stale_schedule, _ = CrontabSchedule.objects.get_or_create( + minute="*/15", + hour="*", + day_of_week="*", + day_of_month="*", + month_of_year="*", + ) + PeriodicTask.objects.get_or_create( + name="jobs.check_stale_jobs", + defaults={ + "task": "ami.jobs.tasks.check_stale_jobs_task", + "crontab": stale_schedule, + "description": "Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS", + }, + ) + + stats_schedule, _ = CrontabSchedule.objects.get_or_create( + minute="*/5", + hour="*", + day_of_week="*", + day_of_month="*", + month_of_year="*", + ) + PeriodicTask.objects.get_or_create( + name="jobs.log_running_async_job_stats", + defaults={ + "task": "ami.jobs.tasks.log_running_async_job_stats", + "crontab": stats_schedule, + "description": "Log NATS consumer delivered/ack/pending stats for each running async_api job", + }, + ) + + +def delete_periodic_tasks(apps, schema_editor): + from django_celery_beat.models import PeriodicTask + + PeriodicTask.objects.filter( + name__in=["jobs.check_stale_jobs", "jobs.log_running_async_job_stats"], + ).delete() + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0019_job_dispatch_mode"), + ] + + operations = [ + migrations.RunPython(create_periodic_tasks, delete_periodic_tasks), + ] diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index ad3e18ca8..b017a6954 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -407,6 +407,69 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di return results +# Beat schedule is every 15 minutes for check_stale_jobs; expire queued copies +# that accumulate while a worker is unavailable so we don't process a backlog. +_STALE_JOB_BEAT_EXPIRES = 60 * 10 + + +@celery_app.task(soft_time_limit=300, time_limit=360, expires=_STALE_JOB_BEAT_EXPIRES) +def check_stale_jobs_task() -> dict: + """Celery Beat entry point for `check_stale_jobs`. + + Runs the existing stale-job reconciler on a schedule so jobs don't silently + sit in a running state for days when their Celery task is gone or the + worker crashed. Returns a summary dict for flower / task-result visibility. + """ + results = check_stale_jobs() + updated = sum(1 for r in results if r["action"] == "updated") + revoked = sum(1 for r in results if r["action"] == "revoked") + logger.info( + "check_stale_jobs_task finished: %d stale job(s), %d updated from Celery, %d revoked", + len(results), + updated, + revoked, + ) + return {"total": len(results), "updated": updated, "revoked": revoked} + + +# Expire faster than the stale-job task — this is observability, a skipped +# cycle is fine and we'd rather not pile up backlog of snapshot work. +_ASYNC_STATS_BEAT_EXPIRES = 60 * 4 + + +@celery_app.task(soft_time_limit=180, time_limit=240, expires=_ASYNC_STATS_BEAT_EXPIRES) +def log_running_async_job_stats() -> dict: + """Log a NATS consumer snapshot (delivered/ack/pending/redelivered) per running async_api job. + + Writes to the per-job logger so operators see counts in the job's UI log + without waiting for it to finish. Read-only: no status changes. + """ + from ami.jobs.models import Job, JobDispatchMode, JobState + + # Resolve each job's per-job logger synchronously — the property touches Django + # ORM via its JobLogHandler, which is only safe outside the event loop. + running_jobs = list( + Job.objects.filter( + status__in=JobState.running_states(), + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + ) + if not running_jobs: + return {"checked": 0} + + async def _snapshot_all(): + for job in running_jobs: + try: + async with TaskQueueManager(job_logger=job.logger) as manager: + await manager.log_consumer_stats_snapshot(job.pk) + except Exception: + # One job's NATS failure must not block snapshots for others. + logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + + async_to_sync(_snapshot_all)() + return {"checked": len(running_jobs)} + + def cleanup_async_job_if_needed(job) -> None: """ Clean up async resources (NATS/Redis) if this job uses them. diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py new file mode 100644 index 000000000..d229cce46 --- /dev/null +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -0,0 +1,92 @@ +from datetime import timedelta +from unittest.mock import AsyncMock, patch + +from django.test import TestCase +from django.utils import timezone + +from ami.jobs.models import Job, JobDispatchMode, JobState +from ami.jobs.tasks import check_stale_jobs_task, log_running_async_job_stats +from ami.main.models import Project + + +class CheckStaleJobsTaskTest(TestCase): + def setUp(self): + self.project = Project.objects.create(name="Beat schedule test project") + + def _create_stale_job(self, status=JobState.STARTED, hours_ago=100): + job = Job.objects.create(project=self.project, name="stale", status=status) + Job.objects.filter(pk=job.pk).update(updated_at=timezone.now() - timedelta(hours=hours_ago)) + job.refresh_from_db() + return job + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + def test_returns_summary_counts(self, _mock_cleanup): + self._create_stale_job() + self._create_stale_job() + result = check_stale_jobs_task() + self.assertEqual(result, {"total": 2, "updated": 0, "revoked": 2}) + + def test_no_stale_jobs_returns_zero_summary(self): + self._create_stale_job(hours_ago=1) # recent — not stale + self.assertEqual(check_stale_jobs_task(), {"total": 0, "updated": 0, "revoked": 0}) + + +class LogRunningAsyncJobStatsTest(TestCase): + def setUp(self): + self.project = Project.objects.create(name="Async snapshot test project") + + def _create_async_job(self, status=JobState.STARTED): + job = Job.objects.create(project=self.project, name=f"async {status}", status=status) + Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API) + job.refresh_from_db() + return job + + def test_no_running_jobs_short_circuits(self): + # A celery job with async dispatch but a final status should be skipped. + self._create_async_job(status=JobState.SUCCESS) + self.assertEqual(log_running_async_job_stats(), {"checked": 0}) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_snapshots_each_running_async_job(self, mock_manager_cls): + job_a = self._create_async_job() + job_b = self._create_async_job() + + instance = mock_manager_cls.return_value + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=False) + instance.log_consumer_stats_snapshot = AsyncMock() + + result = log_running_async_job_stats() + + self.assertEqual(result, {"checked": 2}) + snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list] + self.assertCountEqual(snapshots, [job_a.pk, job_b.pk]) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_one_job_failure_does_not_block_others(self, mock_manager_cls): + job_ok = self._create_async_job() + job_broken = self._create_async_job() + + instance = mock_manager_cls.return_value + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=False) + + calls = [] + + async def _snapshot(job_id): + calls.append(job_id) + if job_id == job_broken.pk: + raise RuntimeError("nats down for this one") + + instance.log_consumer_stats_snapshot = AsyncMock(side_effect=_snapshot) + + result = log_running_async_job_stats() + self.assertEqual(result, {"checked": 2}) + self.assertIn(job_ok.pk, calls) + self.assertIn(job_broken.pk, calls) + + def test_non_async_jobs_skipped(self): + job = Job.objects.create(project=self.project, name="sync job", status=JobState.STARTED) + # default dispatch_mode should not be ASYNC_API + self.assertNotEqual(job.dispatch_mode, JobDispatchMode.ASYNC_API) + self.assertEqual(log_running_async_job_stats(), {"checked": 0}) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 43d9d65e5..23eabe7d3 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -477,6 +477,18 @@ async def _log_final_consumer_stats(self, job_id: int) -> None: redelivered before the consumer vanished. Failures here must NOT block cleanup — if the consumer or stream is already gone, just skip it. """ + await self._log_consumer_stats(job_id, prefix="Finalizing NATS consumer", suffix="before deletion") + + async def log_consumer_stats_snapshot(self, job_id: int) -> None: + """Log a mid-flight snapshot of the consumer state for a running job. + + Used by the periodic `log_running_async_job_stats` beat task so operators + can see deliver/ack/pending counts without waiting for the job to finish. + Tolerant of missing stream/consumer like the cleanup-time variant. + """ + await self._log_consumer_stats(job_id, prefix="NATS consumer status") + + async def _log_consumer_stats(self, job_id: int, *, prefix: str, suffix: str = "") -> None: if self.js is None: return stream_name = self._get_stream_name(job_id) @@ -487,15 +499,15 @@ async def _log_final_consumer_stats(self, job_id: int) -> None: timeout=NATS_JETSTREAM_TIMEOUT, ) except Exception as e: - # Broad catch is intentional here (unlike _ensure_consumer): at - # cleanup time we tolerate any failure — stream gone, consumer - # already deleted, auth, timeout — so the delete calls below - # still get a chance to run. - logger.debug(f"Could not fetch consumer info for {consumer_name} before deletion: {e}") + # Broad catch is intentional: if the consumer or stream is gone we + # just skip — callers (cleanup, periodic snapshot) should never fail + # because we couldn't read stats. + logger.debug(f"Could not fetch consumer info for {consumer_name}: {e}") return + tail = f" {suffix}" if suffix else "" await self.log_async( logging.INFO, - f"Finalizing NATS consumer {consumer_name} before deletion ({self._format_consumer_stats(info)})", + f"{prefix} {consumer_name}{tail} ({self._format_consumer_stats(info)})", ) async def delete_consumer(self, job_id: int) -> bool: diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index d1d651450..9c35a4dae 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -493,6 +493,44 @@ async def test_publish_failure_surfaces_on_job_logger(self): f"expected publish failure on job_logger, got {messages}", ) + async def test_log_consumer_stats_snapshot_writes_current_stats(self): + """The periodic snapshot helper logs delivered/ack/pending WITHOUT + deleting the consumer — it's a mid-flight observability hook.""" + nc, js = self._create_mock_nats_connection() + js.consumer_info.return_value = self._make_consumer_info( + delivered=50, ack_floor=40, num_pending=10, num_ack_pending=10, num_redelivered=2 + ) + + job_logger = self._make_captured_logger() + captured = job_logger._captured # type: ignore[attr-defined] + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager(job_logger=job_logger) as manager: + await manager.log_consumer_stats_snapshot(9) + + messages = [m for _, m in captured] + self.assertTrue( + any("NATS consumer status job-9-consumer" in m for m in messages), + f"expected snapshot line on job_logger, got {messages}", + ) + snapshot_line = next(m for m in messages if "NATS consumer status" in m) + for expected in ("delivered=50", "ack_floor=40", "num_redelivered=2"): + self.assertIn(expected, snapshot_line) + # Must NOT have triggered a delete — this is read-only observability. + js.delete_consumer.assert_not_called() + js.delete_stream.assert_not_called() + + async def test_log_consumer_stats_snapshot_tolerates_missing_consumer(self): + """If the consumer is already gone, the snapshot helper just no-ops.""" + nc, js = self._create_mock_nats_connection() + js.consumer_info.side_effect = nats.js.errors.NotFoundError() + + job_logger = self._make_captured_logger() + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager(job_logger=job_logger) as manager: + await manager.log_consumer_stats_snapshot(99) # must not raise + async def test_no_job_logger_falls_back_to_module_logger_only(self): """When job_logger is None (e.g., module-level uses like advisory listener), lifecycle logs must still be emitted to the module logger From f91bb666209ceb7ea34d3a22ef776cafcaad70a1 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 14 Apr 2026 08:14:14 -0700 Subject: [PATCH 2/6] fix(jobs): address review comments on PR #1227 - log_running_async_job_stats: reuse a single TaskQueueManager per tick instead of opening N NATS connections. Cost is now O(1) in the number of running async jobs rather than O(n). Guarded outer connection setup so a NATS outage drops the tick cleanly instead of crashing the task. - check_stale_jobs_task: bump expires from 10 to 14 minutes so a delayed copy still runs within the 15-minute schedule instead of expiring before a worker picks it up under broker pressure. - migration 0020: use apps.get_model for django_celery_beat models and declare an explicit migration dependency so the data migration uses historical model state. Co-Authored-By: Claude --- ...0020_schedule_job_monitoring_beat_tasks.py | 6 ++-- ami/jobs/tasks.py | 32 ++++++++++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py index c6996d8e9..d266cfeb6 100644 --- a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py +++ b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py @@ -2,7 +2,8 @@ def create_periodic_tasks(apps, schema_editor): - from django_celery_beat.models import CrontabSchedule, PeriodicTask + CrontabSchedule = apps.get_model("django_celery_beat", "CrontabSchedule") + PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask") stale_schedule, _ = CrontabSchedule.objects.get_or_create( minute="*/15", @@ -38,7 +39,7 @@ def create_periodic_tasks(apps, schema_editor): def delete_periodic_tasks(apps, schema_editor): - from django_celery_beat.models import PeriodicTask + PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask") PeriodicTask.objects.filter( name__in=["jobs.check_stale_jobs", "jobs.log_running_async_job_stats"], @@ -48,6 +49,7 @@ def delete_periodic_tasks(apps, schema_editor): class Migration(migrations.Migration): dependencies = [ ("jobs", "0019_job_dispatch_mode"), + ("django_celery_beat", "0018_improve_crontab_helptext"), ] operations = [ diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index b017a6954..c3860044b 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -409,7 +409,10 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di # Beat schedule is every 15 minutes for check_stale_jobs; expire queued copies # that accumulate while a worker is unavailable so we don't process a backlog. -_STALE_JOB_BEAT_EXPIRES = 60 * 10 +# Kept just under the schedule interval so a backlog is dropped but a single +# delayed copy still runs. Going below the interval would risk every copy +# expiring before a worker picks it up under moderate broker pressure. +_STALE_JOB_BEAT_EXPIRES = 60 * 14 @celery_app.task(soft_time_limit=300, time_limit=360, expires=_STALE_JOB_BEAT_EXPIRES) @@ -458,13 +461,26 @@ def log_running_async_job_stats() -> dict: return {"checked": 0} async def _snapshot_all(): - for job in running_jobs: - try: - async with TaskQueueManager(job_logger=job.logger) as manager: - await manager.log_consumer_stats_snapshot(job.pk) - except Exception: - # One job's NATS failure must not block snapshots for others. - logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + # Reuse one TaskQueueManager (and thus one NATS connection) for the + # whole tick so cost stays O(1) in the number of running jobs. The + # manager's `job_logger` attribute is read fresh by `log_async` on + # every call, so swapping it per iteration routes lifecycle lines to + # the right job UI. `_setup_advisory_stream` (called in `__aenter__`) + # only logs via the module logger, so the initial logger choice does + # not leak into an unrelated job's log. + try: + async with TaskQueueManager(job_logger=running_jobs[0].logger) as manager: + for job in running_jobs: + try: + manager.job_logger = job.logger + await manager.log_consumer_stats_snapshot(job.pk) + except Exception: + # One job's NATS failure must not block snapshots for others. + logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + except Exception: + # Connection setup (or teardown) itself failed — log once and + # wait for the next beat tick rather than crashing the task. + logger.exception("Failed to open NATS connection for consumer snapshots") async_to_sync(_snapshot_all)() return {"checked": len(running_jobs)} From 081bdb9a93328f3b8b383a74020f3d7914bf57fa Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 14 Apr 2026 11:20:55 -0700 Subject: [PATCH 3/6] fix(jobs): fall back to per-job manager when shared path fails MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If opening the shared TaskQueueManager raises at setup (or teardown), try each job with its own fresh manager before giving up. Defends against a regression in the shared mutation pattern silently costing us snapshot visibility for every running async job on every tick. If NATS itself is down, the fallback loop will fail too and log once per job — same end-state as before the refactor, just more noisy. Co-Authored-By: Claude --- ami/jobs/tasks.py | 26 +++++++++++++++++----- ami/jobs/tests/test_periodic_beat_tasks.py | 22 ++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index c3860044b..9967c1286 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -460,10 +460,14 @@ def log_running_async_job_stats() -> dict: if not running_jobs: return {"checked": 0} + async def _snapshot_one(job) -> None: + async with TaskQueueManager(job_logger=job.logger) as manager: + await manager.log_consumer_stats_snapshot(job.pk) + async def _snapshot_all(): - # Reuse one TaskQueueManager (and thus one NATS connection) for the - # whole tick so cost stays O(1) in the number of running jobs. The - # manager's `job_logger` attribute is read fresh by `log_async` on + # Fast path: reuse one TaskQueueManager (and thus one NATS connection) + # for the whole tick so cost stays O(1) in the number of running jobs. + # The manager's `job_logger` attribute is read fresh by `log_async` on # every call, so swapping it per iteration routes lifecycle lines to # the right job UI. `_setup_advisory_stream` (called in `__aenter__`) # only logs via the module logger, so the initial logger choice does @@ -477,10 +481,20 @@ async def _snapshot_all(): except Exception: # One job's NATS failure must not block snapshots for others. logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + return except Exception: - # Connection setup (or teardown) itself failed — log once and - # wait for the next beat tick rather than crashing the task. - logger.exception("Failed to open NATS connection for consumer snapshots") + # Shared path failed at setup/teardown — could be NATS down (in which + # case per-job will fail identically and we'll log once per job) or + # a bug specific to reusing one manager (in which case per-job still + # works and we keep getting snapshots). Fall back rather than losing + # visibility for the whole tick. + logger.exception("Shared-connection snapshot failed; falling back to per-job connections") + + for job in running_jobs: + try: + await _snapshot_one(job) + except Exception: + logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) async_to_sync(_snapshot_all)() return {"checked": len(running_jobs)} diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py index d229cce46..6d0b7f3db 100644 --- a/ami/jobs/tests/test_periodic_beat_tasks.py +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -85,6 +85,28 @@ async def _snapshot(job_id): self.assertIn(job_ok.pk, calls) self.assertIn(job_broken.pk, calls) + @patch("ami.jobs.tasks.TaskQueueManager") + def test_shared_connection_failure_falls_back_to_per_job(self, mock_manager_cls): + job_a = self._create_async_job() + job_b = self._create_async_job() + + instance = mock_manager_cls.return_value + # First __aenter__ (shared path) blows up; subsequent ones (per-job + # fallback) succeed. Simulates a bug that only affects the shared path. + instance.__aenter__ = AsyncMock( + side_effect=[RuntimeError("shared path broken"), instance, instance], + ) + instance.__aexit__ = AsyncMock(return_value=False) + instance.log_consumer_stats_snapshot = AsyncMock() + + result = log_running_async_job_stats() + + self.assertEqual(result, {"checked": 2}) + # Shared attempt + one fresh manager per job = 3 __aenter__ calls total. + self.assertEqual(instance.__aenter__.await_count, 3) + snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list] + self.assertCountEqual(snapshots, [job_a.pk, job_b.pk]) + def test_non_async_jobs_skipped(self): job = Job.objects.create(project=self.project, name="sync job", status=JobState.STARTED) # default dispatch_mode should not be ASYNC_API From 08036ca65778538a10d643a5ccb5e86da78769e3 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 14 Apr 2026 12:34:35 -0700 Subject: [PATCH 4/6] refactor(jobs): rename beat task to jobs_health_check umbrella MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The periodic task is renamed from `check_stale_jobs_task` to `jobs_health_check` so new job-health checks can share the 15-minute cadence and `expires` guarantees without a new beat entry. Current body runs a single `_run_stale_jobs_check()`; future checks plug in alongside it and return the same `{checked, fixed, unfixable}` shape. Names chosen to parallel #1188's integrity-check pattern: - Beat task = noun phrase (`jobs_health_check`) — reads well in flower - (Future) management command = verb (`manage.py check_jobs`) Migration 0020 updated in place since it hasn't shipped yet. Co-Authored-By: Claude --- ...0020_schedule_job_monitoring_beat_tasks.py | 6 ++-- ami/jobs/tasks.py | 32 +++++++++++++------ ami/jobs/tests/test_periodic_beat_tasks.py | 15 +++++---- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py index d266cfeb6..d951e6694 100644 --- a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py +++ b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py @@ -13,11 +13,11 @@ def create_periodic_tasks(apps, schema_editor): month_of_year="*", ) PeriodicTask.objects.get_or_create( - name="jobs.check_stale_jobs", + name="jobs.health_check", defaults={ - "task": "ami.jobs.tasks.check_stale_jobs_task", + "task": "ami.jobs.tasks.jobs_health_check", "crontab": stale_schedule, - "description": "Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS", + "description": "Umbrella job-health checks (stale job reconciler, future integrity checks)", }, ) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 9967c1286..ec6903aad 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -407,32 +407,44 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di return results -# Beat schedule is every 15 minutes for check_stale_jobs; expire queued copies +# Beat schedule is every 15 minutes for jobs_health_check; expire queued copies # that accumulate while a worker is unavailable so we don't process a backlog. # Kept just under the schedule interval so a backlog is dropped but a single # delayed copy still runs. Going below the interval would risk every copy # expiring before a worker picks it up under moderate broker pressure. -_STALE_JOB_BEAT_EXPIRES = 60 * 14 +_JOBS_HEALTH_BEAT_EXPIRES = 60 * 14 -@celery_app.task(soft_time_limit=300, time_limit=360, expires=_STALE_JOB_BEAT_EXPIRES) -def check_stale_jobs_task() -> dict: - """Celery Beat entry point for `check_stale_jobs`. +def _run_stale_jobs_check() -> dict: + """Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS. - Runs the existing stale-job reconciler on a schedule so jobs don't silently - sit in a running state for days when their Celery task is gone or the - worker crashed. Returns a summary dict for flower / task-result visibility. + Returns `{checked, fixed, unfixable}` so it composes with other health + checks that return the same shape. """ results = check_stale_jobs() updated = sum(1 for r in results if r["action"] == "updated") revoked = sum(1 for r in results if r["action"] == "revoked") logger.info( - "check_stale_jobs_task finished: %d stale job(s), %d updated from Celery, %d revoked", + "stale_jobs check: %d stale job(s), %d updated from Celery, %d revoked", len(results), updated, revoked, ) - return {"total": len(results), "updated": updated, "revoked": revoked} + return {"checked": len(results), "fixed": updated + revoked, "unfixable": 0} + + +@celery_app.task(soft_time_limit=300, time_limit=360, expires=_JOBS_HEALTH_BEAT_EXPIRES) +def jobs_health_check() -> dict: + """Umbrella beat task for periodic job-health checks. + + Each sub-check returns `{checked, fixed, unfixable}`. Add new checks here + — they share the 15-minute cadence and `expires` guarantees. Keep checks + cheap (DB scans, light reconciliation); long-running repair work belongs + in its own task. + """ + return { + "stale_jobs": _run_stale_jobs_check(), + } # Expire faster than the stale-job task — this is observability, a skipped diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py index 6d0b7f3db..31035bf28 100644 --- a/ami/jobs/tests/test_periodic_beat_tasks.py +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -5,11 +5,11 @@ from django.utils import timezone from ami.jobs.models import Job, JobDispatchMode, JobState -from ami.jobs.tasks import check_stale_jobs_task, log_running_async_job_stats +from ami.jobs.tasks import jobs_health_check, log_running_async_job_stats from ami.main.models import Project -class CheckStaleJobsTaskTest(TestCase): +class JobsHealthCheckTest(TestCase): def setUp(self): self.project = Project.objects.create(name="Beat schedule test project") @@ -20,15 +20,18 @@ def _create_stale_job(self, status=JobState.STARTED, hours_ago=100): return job @patch("ami.jobs.tasks.cleanup_async_job_if_needed") - def test_returns_summary_counts(self, _mock_cleanup): + def test_returns_nested_summary_counts(self, _mock_cleanup): self._create_stale_job() self._create_stale_job() - result = check_stale_jobs_task() - self.assertEqual(result, {"total": 2, "updated": 0, "revoked": 2}) + result = jobs_health_check() + self.assertEqual(result, {"stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0}}) def test_no_stale_jobs_returns_zero_summary(self): self._create_stale_job(hours_ago=1) # recent — not stale - self.assertEqual(check_stale_jobs_task(), {"total": 0, "updated": 0, "revoked": 0}) + self.assertEqual( + jobs_health_check(), + {"stale_jobs": {"checked": 0, "fixed": 0, "unfixable": 0}}, + ) class LogRunningAsyncJobStatsTest(TestCase): From 6dc7e6ec7a8a13c8a9ca01ac1e49ee8c949a1b37 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 14 Apr 2026 14:14:37 -0700 Subject: [PATCH 5/6] refactor(jobs): fold snapshot task into umbrella, adopt IntegrityCheckResult MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses the takeaway-review findings on PR #1227: - Fix reverse migration: delete the row we create. The forward step registers `jobs.health_check` but the old `delete_periodic_tasks` still referenced the pre-rename `jobs.check_stale_jobs`, leaving a stranded row on rollback. - Collapse `log_running_async_job_stats` into a sub-check of the umbrella. Drops the second `PeriodicTask` and the shared-connection fallback path — on a 15-minute cadence there's no reason to keep two beat tasks or defend against a per-manager bug, and a quietly hung job now gets a snapshot in the same tick the reconciler will decide whether to revoke it. - Adopt `IntegrityCheckResult` as the shared sub-check shape, housed in a new `ami.main.checks.schemas` module so PR #1188 can re-target its import without a merge-order dance. Wrap the two sub-checks in a `JobsHealthCheckResult` parent dataclass; the umbrella returns `dataclasses.asdict(...)` for celery's JSON backend. Observation checks leave `fixed=0` and count per-item errors in `unfixable`. Tests collapse to one `JobsHealthCheckTest` class covering both sub-checks (reconcile + snapshot) and the edge cases that matter: per-job snapshot failure, shared-connection setup failure, non-async jobs skipped, idle deployment returns all zeros. Co-Authored-By: Claude --- ...0020_schedule_job_monitoring_beat_tasks.py | 30 +--- ami/jobs/tasks.py | 155 +++++++++--------- ami/jobs/tests/test_periodic_beat_tasks.py | 122 +++++++------- ami/main/checks/__init__.py | 12 ++ ami/main/checks/schemas.py | 27 +++ ami/ml/orchestration/nats_queue.py | 7 +- 6 files changed, 195 insertions(+), 158 deletions(-) create mode 100644 ami/main/checks/__init__.py create mode 100644 ami/main/checks/schemas.py diff --git a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py index d951e6694..6f1ee3ef3 100644 --- a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py +++ b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py @@ -5,7 +5,7 @@ def create_periodic_tasks(apps, schema_editor): CrontabSchedule = apps.get_model("django_celery_beat", "CrontabSchedule") PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask") - stale_schedule, _ = CrontabSchedule.objects.get_or_create( + schedule, _ = CrontabSchedule.objects.get_or_create( minute="*/15", hour="*", day_of_week="*", @@ -16,34 +16,18 @@ def create_periodic_tasks(apps, schema_editor): name="jobs.health_check", defaults={ "task": "ami.jobs.tasks.jobs_health_check", - "crontab": stale_schedule, - "description": "Umbrella job-health checks (stale job reconciler, future integrity checks)", - }, - ) - - stats_schedule, _ = CrontabSchedule.objects.get_or_create( - minute="*/5", - hour="*", - day_of_week="*", - day_of_month="*", - month_of_year="*", - ) - PeriodicTask.objects.get_or_create( - name="jobs.log_running_async_job_stats", - defaults={ - "task": "ami.jobs.tasks.log_running_async_job_stats", - "crontab": stats_schedule, - "description": "Log NATS consumer delivered/ack/pending stats for each running async_api job", + "crontab": schedule, + "description": ( + "Umbrella job-health checks: stale-job reconciler plus a NATS " + "consumer snapshot for each running async_api job." + ), }, ) def delete_periodic_tasks(apps, schema_editor): PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask") - - PeriodicTask.objects.filter( - name__in=["jobs.check_stale_jobs", "jobs.log_running_async_job_stats"], - ).delete() + PeriodicTask.objects.filter(name="jobs.health_check").delete() class Migration(migrations.Migration): diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index ec6903aad..1371a6078 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import functools import logging @@ -9,6 +10,7 @@ from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction +from ami.main.checks.schemas import IntegrityCheckResult from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse @@ -407,20 +409,31 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di return results -# Beat schedule is every 15 minutes for jobs_health_check; expire queued copies -# that accumulate while a worker is unavailable so we don't process a backlog. -# Kept just under the schedule interval so a backlog is dropped but a single -# delayed copy still runs. Going below the interval would risk every copy -# expiring before a worker picks it up under moderate broker pressure. +# Expire queued copies that accumulate while a worker is unavailable so we +# don't process a backlog when a worker reconnects. Kept below the 15-minute +# schedule interval so a backlog is dropped but a single delayed copy still +# runs. Going well below the interval would risk every copy expiring before +# a worker picks it up under moderate broker pressure — change this in lock- +# step with the crontab in migration 0020. _JOBS_HEALTH_BEAT_EXPIRES = 60 * 14 -def _run_stale_jobs_check() -> dict: - """Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS. +@dataclasses.dataclass +class JobsHealthCheckResult: + """Nested result of one :func:`jobs_health_check` tick. - Returns `{checked, fixed, unfixable}` so it composes with other health - checks that return the same shape. + Each field is the summary for one sub-check and uses the shared + :class:`IntegrityCheckResult` shape so operators see a uniform + ``checked / fixed / unfixable`` triple regardless of which check ran. + Add a new field here when adding a sub-check to the umbrella. """ + + stale_jobs: IntegrityCheckResult + running_job_snapshots: IntegrityCheckResult + + +def _run_stale_jobs_check() -> IntegrityCheckResult: + """Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS.""" results = check_stale_jobs() updated = sum(1 for r in results if r["action"] == "updated") revoked = sum(1 for r in results if r["action"] == "revoked") @@ -430,39 +443,20 @@ def _run_stale_jobs_check() -> dict: updated, revoked, ) - return {"checked": len(results), "fixed": updated + revoked, "unfixable": 0} - - -@celery_app.task(soft_time_limit=300, time_limit=360, expires=_JOBS_HEALTH_BEAT_EXPIRES) -def jobs_health_check() -> dict: - """Umbrella beat task for periodic job-health checks. - - Each sub-check returns `{checked, fixed, unfixable}`. Add new checks here - — they share the 15-minute cadence and `expires` guarantees. Keep checks - cheap (DB scans, light reconciliation); long-running repair work belongs - in its own task. - """ - return { - "stale_jobs": _run_stale_jobs_check(), - } - - -# Expire faster than the stale-job task — this is observability, a skipped -# cycle is fine and we'd rather not pile up backlog of snapshot work. -_ASYNC_STATS_BEAT_EXPIRES = 60 * 4 + return IntegrityCheckResult(checked=len(results), fixed=updated + revoked, unfixable=0) -@celery_app.task(soft_time_limit=180, time_limit=240, expires=_ASYNC_STATS_BEAT_EXPIRES) -def log_running_async_job_stats() -> dict: - """Log a NATS consumer snapshot (delivered/ack/pending/redelivered) per running async_api job. +def _run_running_job_snapshot_check() -> IntegrityCheckResult: + """Log a NATS consumer snapshot for each running async_api job. - Writes to the per-job logger so operators see counts in the job's UI log - without waiting for it to finish. Read-only: no status changes. + Observation-only: ``fixed`` stays 0 because no state is altered. Jobs + that error during snapshot are counted in ``unfixable`` — a persistently + stuck job will be picked up on the next tick by ``_run_stale_jobs_check``. """ from ami.jobs.models import Job, JobDispatchMode, JobState - # Resolve each job's per-job logger synchronously — the property touches Django - # ORM via its JobLogHandler, which is only safe outside the event loop. + # Resolve each job's per-job logger synchronously — the property touches + # Django ORM via JobLogHandler, which is only safe outside the event loop. running_jobs = list( Job.objects.filter( status__in=JobState.running_states(), @@ -470,46 +464,57 @@ def log_running_async_job_stats() -> dict: ) ) if not running_jobs: - return {"checked": 0} - - async def _snapshot_one(job) -> None: - async with TaskQueueManager(job_logger=job.logger) as manager: - await manager.log_consumer_stats_snapshot(job.pk) - - async def _snapshot_all(): - # Fast path: reuse one TaskQueueManager (and thus one NATS connection) - # for the whole tick so cost stays O(1) in the number of running jobs. - # The manager's `job_logger` attribute is read fresh by `log_async` on - # every call, so swapping it per iteration routes lifecycle lines to - # the right job UI. `_setup_advisory_stream` (called in `__aenter__`) - # only logs via the module logger, so the initial logger choice does - # not leak into an unrelated job's log. - try: - async with TaskQueueManager(job_logger=running_jobs[0].logger) as manager: - for job in running_jobs: - try: - manager.job_logger = job.logger - await manager.log_consumer_stats_snapshot(job.pk) - except Exception: - # One job's NATS failure must not block snapshots for others. - logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) - return - except Exception: - # Shared path failed at setup/teardown — could be NATS down (in which - # case per-job will fail identically and we'll log once per job) or - # a bug specific to reusing one manager (in which case per-job still - # works and we keep getting snapshots). Fall back rather than losing - # visibility for the whole tick. - logger.exception("Shared-connection snapshot failed; falling back to per-job connections") - - for job in running_jobs: - try: - await _snapshot_one(job) - except Exception: - logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + return IntegrityCheckResult() + + errors = 0 - async_to_sync(_snapshot_all)() - return {"checked": len(running_jobs)} + async def _snapshot_all() -> None: + nonlocal errors + # One NATS connection per tick — on a 15-min cadence a per-job fallback + # is not worth the code. If the shared connection fails to set up, we + # skip this tick's snapshots and try fresh on the next one. + async with TaskQueueManager(job_logger=running_jobs[0].logger) as manager: + for job in running_jobs: + try: + # `log_async` reads `job_logger` fresh each call, so + # swapping per iteration routes lifecycle lines to the + # right job's UI log. + manager.job_logger = job.logger + await manager.log_consumer_stats_snapshot(job.pk) + except Exception: + errors += 1 + logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + + try: + async_to_sync(_snapshot_all)() + except Exception: + logger.exception("Shared-connection snapshot setup failed; skipping this tick") + errors = len(running_jobs) + + logger.info( + "running_job_snapshots check: %d running async job(s), %d error(s)", + len(running_jobs), + errors, + ) + return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors) + + +@celery_app.task(soft_time_limit=300, time_limit=360, expires=_JOBS_HEALTH_BEAT_EXPIRES) +def jobs_health_check() -> dict: + """Umbrella beat task for periodic job-health checks. + + Composes reconciliation (stale jobs) with observation (NATS consumer + snapshots for running async jobs) so both land in the same 15-minute + tick — a quietly hung async job gets a snapshot entry right before the + reconciler decides whether to revoke it. Returns the serialized form of + :class:`JobsHealthCheckResult` so celery's default JSON backend can store + it; add new sub-checks by extending that dataclass and calling them here. + """ + result = JobsHealthCheckResult( + stale_jobs=_run_stale_jobs_check(), + running_job_snapshots=_run_running_job_snapshot_check(), + ) + return dataclasses.asdict(result) def cleanup_async_job_if_needed(job) -> None: diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py index 31035bf28..9f6fd9bf2 100644 --- a/ami/jobs/tests/test_periodic_beat_tasks.py +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -5,10 +5,16 @@ from django.utils import timezone from ami.jobs.models import Job, JobDispatchMode, JobState -from ami.jobs.tasks import jobs_health_check, log_running_async_job_stats +from ami.jobs.tasks import jobs_health_check from ami.main.models import Project +def _empty_check_dict() -> dict: + return {"checked": 0, "fixed": 0, "unfixable": 0} + + +@patch("ami.jobs.tasks.cleanup_async_job_if_needed") +@patch("ami.jobs.tasks.TaskQueueManager") class JobsHealthCheckTest(TestCase): def setUp(self): self.project = Project.objects.create(name="Beat schedule test project") @@ -19,60 +25,62 @@ def _create_stale_job(self, status=JobState.STARTED, hours_ago=100): job.refresh_from_db() return job - @patch("ami.jobs.tasks.cleanup_async_job_if_needed") - def test_returns_nested_summary_counts(self, _mock_cleanup): + def _create_async_job(self, status=JobState.STARTED): + job = Job.objects.create(project=self.project, name=f"async {status}", status=status) + Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API) + job.refresh_from_db() + return job + + def _stub_manager(self, mock_manager_cls) -> AsyncMock: + instance = mock_manager_cls.return_value + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=False) + instance.log_consumer_stats_snapshot = AsyncMock() + return instance + + def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup): self._create_stale_job() self._create_stale_job() + self._stub_manager(mock_manager_cls) + result = jobs_health_check() - self.assertEqual(result, {"stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0}}) - def test_no_stale_jobs_returns_zero_summary(self): - self._create_stale_job(hours_ago=1) # recent — not stale self.assertEqual( - jobs_health_check(), - {"stale_jobs": {"checked": 0, "fixed": 0, "unfixable": 0}}, + result, + { + "stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0}, + "running_job_snapshots": _empty_check_dict(), + }, ) + def test_idle_deployment_returns_all_zeros(self, mock_manager_cls, _mock_cleanup): + # No stale jobs, no running async jobs. + self._create_stale_job(hours_ago=1) # recent — not stale + self._stub_manager(mock_manager_cls) -class LogRunningAsyncJobStatsTest(TestCase): - def setUp(self): - self.project = Project.objects.create(name="Async snapshot test project") - - def _create_async_job(self, status=JobState.STARTED): - job = Job.objects.create(project=self.project, name=f"async {status}", status=status) - Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API) - job.refresh_from_db() - return job - - def test_no_running_jobs_short_circuits(self): - # A celery job with async dispatch but a final status should be skipped. - self._create_async_job(status=JobState.SUCCESS) - self.assertEqual(log_running_async_job_stats(), {"checked": 0}) + self.assertEqual( + jobs_health_check(), + { + "stale_jobs": _empty_check_dict(), + "running_job_snapshots": _empty_check_dict(), + }, + ) - @patch("ami.jobs.tasks.TaskQueueManager") - def test_snapshots_each_running_async_job(self, mock_manager_cls): + def test_snapshots_each_running_async_job(self, mock_manager_cls, _mock_cleanup): job_a = self._create_async_job() job_b = self._create_async_job() + instance = self._stub_manager(mock_manager_cls) - instance = mock_manager_cls.return_value - instance.__aenter__ = AsyncMock(return_value=instance) - instance.__aexit__ = AsyncMock(return_value=False) - instance.log_consumer_stats_snapshot = AsyncMock() - - result = log_running_async_job_stats() + result = jobs_health_check() - self.assertEqual(result, {"checked": 2}) + self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 0}) snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list] self.assertCountEqual(snapshots, [job_a.pk, job_b.pk]) - @patch("ami.jobs.tasks.TaskQueueManager") - def test_one_job_failure_does_not_block_others(self, mock_manager_cls): + def test_one_job_snapshot_failure_counts_as_unfixable(self, mock_manager_cls, _mock_cleanup): job_ok = self._create_async_job() job_broken = self._create_async_job() - - instance = mock_manager_cls.return_value - instance.__aenter__ = AsyncMock(return_value=instance) - instance.__aexit__ = AsyncMock(return_value=False) + instance = self._stub_manager(mock_manager_cls) calls = [] @@ -83,35 +91,35 @@ async def _snapshot(job_id): instance.log_consumer_stats_snapshot = AsyncMock(side_effect=_snapshot) - result = log_running_async_job_stats() - self.assertEqual(result, {"checked": 2}) + result = jobs_health_check() + + # Both jobs were attempted; only the broken one failed. + self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 1}) self.assertIn(job_ok.pk, calls) self.assertIn(job_broken.pk, calls) - @patch("ami.jobs.tasks.TaskQueueManager") - def test_shared_connection_failure_falls_back_to_per_job(self, mock_manager_cls): - job_a = self._create_async_job() - job_b = self._create_async_job() + def test_shared_connection_setup_failure_marks_all_unfixable(self, mock_manager_cls, _mock_cleanup): + self._create_async_job() + self._create_async_job() instance = mock_manager_cls.return_value - # First __aenter__ (shared path) blows up; subsequent ones (per-job - # fallback) succeed. Simulates a bug that only affects the shared path. - instance.__aenter__ = AsyncMock( - side_effect=[RuntimeError("shared path broken"), instance, instance], - ) + instance.__aenter__ = AsyncMock(side_effect=RuntimeError("nats down")) instance.__aexit__ = AsyncMock(return_value=False) instance.log_consumer_stats_snapshot = AsyncMock() - result = log_running_async_job_stats() + result = jobs_health_check() - self.assertEqual(result, {"checked": 2}) - # Shared attempt + one fresh manager per job = 3 __aenter__ calls total. - self.assertEqual(instance.__aenter__.await_count, 3) - snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list] - self.assertCountEqual(snapshots, [job_a.pk, job_b.pk]) + # All running jobs are counted as unfixable for this tick; no + # snapshots ran and the shared-connection error was swallowed. + self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 2}) + instance.log_consumer_stats_snapshot.assert_not_awaited() - def test_non_async_jobs_skipped(self): + def test_non_async_running_jobs_are_ignored_by_snapshot_check(self, mock_manager_cls, _mock_cleanup): job = Job.objects.create(project=self.project, name="sync job", status=JobState.STARTED) - # default dispatch_mode should not be ASYNC_API self.assertNotEqual(job.dispatch_mode, JobDispatchMode.ASYNC_API) - self.assertEqual(log_running_async_job_stats(), {"checked": 0}) + instance = self._stub_manager(mock_manager_cls) + + result = jobs_health_check() + + self.assertEqual(result["running_job_snapshots"], _empty_check_dict()) + instance.log_consumer_stats_snapshot.assert_not_awaited() diff --git a/ami/main/checks/__init__.py b/ami/main/checks/__init__.py new file mode 100644 index 000000000..506d49ffe --- /dev/null +++ b/ami/main/checks/__init__.py @@ -0,0 +1,12 @@ +"""Integrity and health check primitives shared across apps. + +Sub-modules in this package (added by per-domain check PRs such as +``ami.main.checks.occurrences`` in #1188) define ``get_*`` and +``reconcile_*`` function pairs. The shared result schema lives in +:mod:`ami.main.checks.schemas` so reconciliation and observation checks +across apps return the same shape. +""" + +from ami.main.checks.schemas import IntegrityCheckResult + +__all__ = ["IntegrityCheckResult"] diff --git a/ami/main/checks/schemas.py b/ami/main/checks/schemas.py new file mode 100644 index 000000000..75347b2aa --- /dev/null +++ b/ami/main/checks/schemas.py @@ -0,0 +1,27 @@ +"""Shared result schemas for integrity and health checks. + +A check is any function that inspects some slice of state and returns an +:class:`IntegrityCheckResult`. Reconciliation checks populate ``fixed`` +with the number of rows actually mutated; observation checks (e.g. +logging a snapshot) keep ``fixed`` at 0 and use ``unfixable`` to count +items the check could not complete for. +""" + +import dataclasses + + +@dataclasses.dataclass +class IntegrityCheckResult: + """Summary of a single integrity or health check pass. + + Attributes: + checked: Rows / items the check inspected this pass. + fixed: Rows the check mutated to a correct state. Observation-only + checks must leave this at 0 — ``fixed`` means state was altered. + unfixable: Rows the check inspected but could not repair or observe + (for observation checks this counts errors per item). + """ + + checked: int = 0 + fixed: int = 0 + unfixable: int = 0 diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 23eabe7d3..767298af4 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -482,9 +482,10 @@ async def _log_final_consumer_stats(self, job_id: int) -> None: async def log_consumer_stats_snapshot(self, job_id: int) -> None: """Log a mid-flight snapshot of the consumer state for a running job. - Used by the periodic `log_running_async_job_stats` beat task so operators - can see deliver/ack/pending counts without waiting for the job to finish. - Tolerant of missing stream/consumer like the cleanup-time variant. + Called by the ``running_job_snapshots`` sub-check of the periodic + ``jobs_health_check`` beat task so operators can see deliver/ack/pending + counts without waiting for the job to finish. Tolerant of missing + stream/consumer like the cleanup-time variant. """ await self._log_consumer_stats(job_id, prefix="NATS consumer status") From 6726eee55e32e409d1c933935958d77fa01d10e1 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 14 Apr 2026 14:30:47 -0700 Subject: [PATCH 6/6] fix(jobs): isolate sub-checks + pre-resolve loggers off event loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review findings from the re-review on 6dc7e6ec: - The umbrella had no guard around sub-check calls, so a DB hiccup in ``_run_stale_jobs_check`` would kill the snapshot check and fail the whole task. Wrap each call in ``_safe_run_sub_check``, which catches, logs, and returns ``IntegrityCheckResult(unfixable=1)`` as a sentinel — operators watching the task result in Flower see the sub-check failed rather than reading all-zeros and assuming all-clear. - ``Job.logger`` attaches a ``JobLogHandler`` on first access which touches the ORM; the file's own docstring says resolve outside the event loop, but two accesses were inside the coroutine. Pre-resolve into a list of ``(job, job_logger)`` tuples before entering ``async_to_sync``. - Escalate the ``running_job_snapshots`` summary log to WARNING when ``errors > 0`` so persistent NATS unavailability is distinguishable from a quiet tick in aggregated logs. - Document that the outer shared-connection except overwrites per-iteration error counts on the rare ``__aexit__`` teardown path. New tests: - ``test_sub_check_exception_does_not_block_the_other`` — patches the snapshot sub-check to raise; stale-jobs still reports correctly and snapshots come back as the ``unfixable=1`` sentinel. - ``test_stale_jobs_fixed_counts_celery_updated_and_revoked_paths`` — one stale job with a terminal Celery ``task_id``, one without; both branches of ``fixed`` counted so a future refactor dropping one branch breaks the test. - Explicit ``fixed == 0`` assertion in the snapshot test locks the observation-only contract. Co-Authored-By: Claude --- ami/jobs/tasks.py | 43 ++++++++++++++----- ami/jobs/tests/test_periodic_beat_tasks.py | 50 ++++++++++++++++++++++ 2 files changed, 83 insertions(+), 10 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 1371a6078..d3c8e9070 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -455,8 +455,6 @@ def _run_running_job_snapshot_check() -> IntegrityCheckResult: """ from ami.jobs.models import Job, JobDispatchMode, JobState - # Resolve each job's per-job logger synchronously — the property touches - # Django ORM via JobLogHandler, which is only safe outside the event loop. running_jobs = list( Job.objects.filter( status__in=JobState.running_states(), @@ -466,6 +464,11 @@ def _run_running_job_snapshot_check() -> IntegrityCheckResult: if not running_jobs: return IntegrityCheckResult() + # Resolve each job's per-job logger synchronously before entering the + # event loop — ``Job.logger`` attaches a ``JobLogHandler`` on first access + # which touches the Django ORM, so it is only safe to call from a sync + # context. + job_loggers = [(job, job.logger) for job in running_jobs] errors = 0 async def _snapshot_all() -> None: @@ -473,13 +476,13 @@ async def _snapshot_all() -> None: # One NATS connection per tick — on a 15-min cadence a per-job fallback # is not worth the code. If the shared connection fails to set up, we # skip this tick's snapshots and try fresh on the next one. - async with TaskQueueManager(job_logger=running_jobs[0].logger) as manager: - for job in running_jobs: + async with TaskQueueManager(job_logger=job_loggers[0][1]) as manager: + for job, job_logger in job_loggers: try: - # `log_async` reads `job_logger` fresh each call, so + # ``log_async`` reads ``job_logger`` fresh each call, so # swapping per iteration routes lifecycle lines to the # right job's UI log. - manager.job_logger = job.logger + manager.job_logger = job_logger await manager.log_consumer_stats_snapshot(job.pk) except Exception: errors += 1 @@ -488,10 +491,15 @@ async def _snapshot_all() -> None: try: async_to_sync(_snapshot_all)() except Exception: - logger.exception("Shared-connection snapshot setup failed; skipping this tick") + # Covers both ``__aenter__`` setup failures (no iteration ran) and the + # rare ``__aexit__`` teardown failure after a clean loop. In the + # teardown case this overwrites the per-iteration count with the total + # — accepted: a persistent failure will show up again next tick. + logger.exception("Shared-connection snapshot failed; marking tick unfixable") errors = len(running_jobs) - logger.info( + log_fn = logger.warning if errors else logger.info + log_fn( "running_job_snapshots check: %d running async job(s), %d error(s)", len(running_jobs), errors, @@ -499,6 +507,21 @@ async def _snapshot_all() -> None: return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors) +def _safe_run_sub_check(name: str, fn: Callable[[], IntegrityCheckResult]) -> IntegrityCheckResult: + """Run one umbrella sub-check, returning an ``unfixable=1`` sentinel on failure. + + The umbrella composes independent sub-checks; one failing must not block + the others. A raised exception is logged and surfaced as a single + ``unfixable`` entry so operators watching the task result in Flower see + the check failed rather than reading zero and assuming all-clear. + """ + try: + return fn() + except Exception: + logger.exception("%s sub-check failed; continuing umbrella", name) + return IntegrityCheckResult(checked=0, fixed=0, unfixable=1) + + @celery_app.task(soft_time_limit=300, time_limit=360, expires=_JOBS_HEALTH_BEAT_EXPIRES) def jobs_health_check() -> dict: """Umbrella beat task for periodic job-health checks. @@ -511,8 +534,8 @@ def jobs_health_check() -> dict: it; add new sub-checks by extending that dataclass and calling them here. """ result = JobsHealthCheckResult( - stale_jobs=_run_stale_jobs_check(), - running_job_snapshots=_run_running_job_snapshot_check(), + stale_jobs=_safe_run_sub_check("stale_jobs", _run_stale_jobs_check), + running_job_snapshots=_safe_run_sub_check("running_job_snapshots", _run_running_job_snapshot_check), ) return dataclasses.asdict(result) diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py index 9f6fd9bf2..eaf2f3368 100644 --- a/ami/jobs/tests/test_periodic_beat_tasks.py +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -74,6 +74,11 @@ def test_snapshots_each_running_async_job(self, mock_manager_cls, _mock_cleanup) result = jobs_health_check() self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 0}) + # Observation-only contract: the snapshot sub-check must never report + # ``fixed > 0`` since it does not mutate state. Lock this in explicitly + # so a future refactor that accidentally increments ``fixed`` breaks + # this assertion rather than silently shipping. + self.assertEqual(result["running_job_snapshots"]["fixed"], 0) snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list] self.assertCountEqual(snapshots, [job_a.pk, job_b.pk]) @@ -123,3 +128,48 @@ def test_non_async_running_jobs_are_ignored_by_snapshot_check(self, mock_manager self.assertEqual(result["running_job_snapshots"], _empty_check_dict()) instance.log_consumer_stats_snapshot.assert_not_awaited() + + def test_sub_check_exception_does_not_block_the_other(self, mock_manager_cls, _mock_cleanup): + # One stale job to prove the reconciler would have had work; the + # snapshot sub-check raises and must not prevent the stale-jobs + # sub-check from running and reporting its own result. + self._create_stale_job() + self._stub_manager(mock_manager_cls) + + with patch( + "ami.jobs.tasks._run_running_job_snapshot_check", + side_effect=RuntimeError("pretend the observation check blew up"), + ): + result = jobs_health_check() + + # Stale-jobs sub-check completes normally and reports the reconciliation. + self.assertEqual(result["stale_jobs"], {"checked": 1, "fixed": 1, "unfixable": 0}) + # Snapshot sub-check returns the `unfixable=1` sentinel on failure so + # operators reading the task result see the check failed, not "nothing + # to do." + self.assertEqual(result["running_job_snapshots"], {"checked": 0, "fixed": 0, "unfixable": 1}) + + def test_stale_jobs_fixed_counts_celery_updated_and_revoked_paths(self, mock_manager_cls, _mock_cleanup): + # Two stale jobs in different reconciliation states: one has a Celery + # task_id that returns a terminal state (counts as "updated from Celery"), + # the other has no task_id and is forced to REVOKED. Both contribute to + # `fixed` — this test guards against a refactor dropping one branch. + from celery import states + + job_with_task = self._create_stale_job() + job_with_task.task_id = "terminal-task" + job_with_task.save(update_fields=["task_id"]) + self._create_stale_job() # no task_id → revoked path + self._stub_manager(mock_manager_cls) + + class _FakeAsyncResult: + def __init__(self, task_id): + self.state = states.SUCCESS if task_id == "terminal-task" else states.PENDING + + # `check_stale_jobs` imports AsyncResult locally from celery.result, + # so patch at source rather than at the call site. + with patch("celery.result.AsyncResult", _FakeAsyncResult): + result = jobs_health_check() + + # checked == 2 (both stale), fixed == 2 (one per branch), unfixable == 0 + self.assertEqual(result["stale_jobs"], {"checked": 2, "fixed": 2, "unfixable": 0})