From fd1277f97fb83802e38f04f43381ee753f309bf5 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 20 Apr 2026 12:02:56 -0700 Subject: [PATCH 1/4] fix(jobs): throttle + defer pipeline heartbeat update Move _mark_pipeline_pull_services_seen off the HTTP request path by dispatching a new Celery task (update_pipeline_pull_services_seen) via .delay() from the /tasks and /result endpoints. The task throttles DB writes to once per ~30s per job, cutting concurrent UPDATE pressure under async_api load. Adds 6 unit tests covering the dispatch, throttle, and no-op edge cases. Co-Authored-By: Claude Sonnet 4.6 --- ami/jobs/tasks.py | 51 +++++++++++ ami/jobs/tests/test_jobs.py | 164 ++++++++++++++++++++++++++++++++++++ ami/jobs/views.py | 32 +++---- 3 files changed, 229 insertions(+), 18 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 183cd5186..c984519ed 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -34,6 +34,57 @@ # "nobody's listening" signal. WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5) +# Minimum interval between heartbeat DB writes for a given job. +# PROCESSING_SERVICE_LAST_SEEN_MAX is 60s; writing at most once per 30s keeps +# last_seen current without hammering the same rows on every concurrent request. +HEARTBEAT_THROTTLE_SECONDS = 30 + + +@celery_app.task( + soft_time_limit=10, + time_limit=15, + # No retries — a missed heartbeat is benign; retrying adds load for no gain. +) +def update_pipeline_pull_services_seen(job_id: int) -> None: + """ + Fire-and-forget heartbeat task: record last_seen/last_seen_live for async + (pull-mode) processing services linked to a job's pipeline. + + Called via .delay() from the tasks and result view endpoints so the HTTP + request is never blocked on this DB write. + + Throttle: skips the UPDATE if all matching services were seen within + HEARTBEAT_THROTTLE_SECONDS, cutting write rate under concurrent requests by + orders of magnitude while keeping last_seen fresh relative to the 60s + PROCESSING_SERVICE_LAST_SEEN_MAX threshold. + + Scope: marks ALL async services on the pipeline within this project as live, + not just the specific service that made the request. Once application-token + auth is available (PR #1117), this should be scoped to the individual + calling service instead. + """ + from ami.jobs.models import Job # avoid circular import + + try: + job = Job.objects.select_related("pipeline").get(pk=job_id) + except Job.DoesNotExist: + return + + if not job.pipeline_id: + return + + now = datetime.datetime.now() + throttle_cutoff = now - datetime.timedelta(seconds=HEARTBEAT_THROTTLE_SECONDS) + + services_qs = job.pipeline.processing_services.async_services().filter(projects=job.project_id) + + # Cheap read: skip the UPDATE if every matching service was seen recently. + recent_seen = services_qs.values_list("last_seen", flat=True) + if recent_seen and all(ts is not None and ts >= throttle_cutoff for ts in recent_seen): + return + + services_qs.update(last_seen=now, last_seen_live=True) + @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) def run_job(self, job_id: int) -> None: diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 847a61f7e..4cb4fb484 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -11,6 +11,7 @@ from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.models.processing_service import ProcessingService from ami.ml.orchestration.jobs import queue_images_to_nats from ami.users.models import User @@ -1016,3 +1017,166 @@ def test_tasks_endpoint_rejects_non_async_jobs(self): resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("async_api", resp.json()[0].lower()) + + +class TestPipelineHeartbeatTask(APITestCase): + """ + Unit tests for update_pipeline_pull_services_seen and the view-level + _mark_pipeline_pull_services_seen fire-and-forget dispatch. + """ + + def setUp(self): + self.project = Project.objects.create(name="Heartbeat Test Project") + self.pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline") + self.pipeline.projects.add(self.project) + self.collection = SourceImageCollection.objects.create(name="HB Collection", project=self.project) + self.job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Heartbeat Test Job", + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + self.service = ProcessingService.objects.create( + name="Heartbeat Worker", + endpoint_url=None, # None = pull-mode / async service + ) + self.service.pipelines.add(self.pipeline) + self.service.projects.add(self.project) + + def test_tasks_endpoint_dispatches_heartbeat_task(self): + """The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" + from unittest.mock import patch + + job = self.job + job.status = JobState.STARTED + job.save(update_fields=["status"]) + + images = [ + SourceImage.objects.create( + path=f"hb_tasks_{i}.jpg", + public_base_url="http://example.com", + project=self.project, + ) + for i in range(2) + ] + queue_images_to_nats(job, images) + + user = User.objects.create_user(email="hbtest@example.com", is_superuser=True, is_active=True) + self.client.force_authenticate(user=user) + + with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay: + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") + + self.assertEqual(resp.status_code, 200) + mock_delay.assert_called_once_with(job.pk) + + def test_result_endpoint_dispatches_heartbeat_task(self): + """The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" + from unittest.mock import MagicMock, patch + + user = User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True) + self.client.force_authenticate(user=user) + + result_data = { + "results": [ + { + "reply_subject": "test.reply.hb", + "result": { + "pipeline": "heartbeat-pipeline", + "algorithms": {}, + "total_time": 0.1, + "source_images": [], + "detections": [], + "errors": None, + }, + } + ] + } + + mock_async_result = MagicMock() + mock_async_result.id = "hb-task-id" + with ( + patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=mock_async_result), + patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay, + ): + result_url = reverse_with_params( + "api:job-result", args=[self.job.pk], params={"project_id": self.project.pk} + ) + resp = self.client.post(result_url, result_data, format="json") + + self.assertEqual(resp.status_code, 200) + mock_delay.assert_called_once_with(self.job.pk) + + def test_heartbeat_task_updates_last_seen_when_stale(self): + """update_pipeline_pull_services_seen writes last_seen when the service is stale.""" + import datetime + + from ami.jobs.tasks import update_pipeline_pull_services_seen + + # Set last_seen to well past the throttle window + old_time = datetime.datetime.now() - datetime.timedelta(minutes=5) + self.service.last_seen = old_time + self.service.last_seen_live = False + self.service.save(update_fields=["last_seen", "last_seen_live"]) + + update_pipeline_pull_services_seen(self.job.pk) + + self.service.refresh_from_db() + self.assertTrue(self.service.last_seen_live) + self.assertGreater(self.service.last_seen, old_time) + + def test_heartbeat_task_skips_update_when_recent(self): + """update_pipeline_pull_services_seen skips the UPDATE when last_seen is within the throttle window.""" + import datetime + + from ami.jobs.tasks import update_pipeline_pull_services_seen + + # Set last_seen to just now — well inside the 30s throttle window + recent_time = datetime.datetime.now() - datetime.timedelta(seconds=5) + self.service.last_seen = recent_time + self.service.last_seen_live = True + self.service.save(update_fields=["last_seen", "last_seen_live"]) + + update_pipeline_pull_services_seen(self.job.pk) + + self.service.refresh_from_db() + # last_seen should not have advanced significantly (throttle skipped the UPDATE) + self.assertAlmostEqual( + self.service.last_seen.timestamp(), + recent_time.timestamp(), + delta=1.0, + ) + + def test_heartbeat_task_no_op_for_missing_job(self): + """update_pipeline_pull_services_seen silently returns when job_id does not exist.""" + from ami.jobs.tasks import update_pipeline_pull_services_seen + + # Should not raise + update_pipeline_pull_services_seen(job_id=999999) + + def test_heartbeat_task_no_op_for_job_without_pipeline(self): + """update_pipeline_pull_services_seen returns early when job has no pipeline.""" + import datetime + + from ami.jobs.tasks import update_pipeline_pull_services_seen + + job_no_pipeline = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="No-pipeline job", + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + + old_time = datetime.datetime.now() - datetime.timedelta(minutes=10) + self.service.last_seen = old_time + self.service.save(update_fields=["last_seen"]) + + update_pipeline_pull_services_seen(job_no_pipeline.pk) + + # Service last_seen should be unchanged because the task returned early + self.service.refresh_from_db() + self.assertAlmostEqual(self.service.last_seen.timestamp(), old_time.timestamp(), delta=1.0) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 47c2461b9..cf8e9e08d 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -24,7 +24,7 @@ MLJobTasksRequestSerializer, MLJobTasksResponseSerializer, ) -from ami.jobs.tasks import process_nats_pipeline_result +from ami.jobs.tasks import process_nats_pipeline_result, update_pipeline_pull_services_seen from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param @@ -52,26 +52,22 @@ def _actor_log_context(request) -> tuple[str, str | None]: def _mark_pipeline_pull_services_seen(job: "Job") -> None: """ - Record a heartbeat for async (pull-mode) processing services linked to the job's pipeline. - - Called on every task-fetch and result-submit request so that the worker's polling activity - keeps last_seen/last_seen_live current. The periodic check_processing_services_online task - will mark services offline if this heartbeat stops arriving within PROCESSING_SERVICE_LAST_SEEN_MAX. - - IMPORTANT: This marks ALL async services on the pipeline within this project as live, not just - the specific service that made the request. If multiple async services share the same pipeline - within a project, a single worker polling will keep all of them appearing online. - Once application-token auth is available (PR #1117), this should be scoped to the individual - calling service instead. + Enqueue a fire-and-forget heartbeat for async (pull-mode) processing services + linked to the job's pipeline. + + Dispatches update_pipeline_pull_services_seen via Celery .delay() so the view + is never blocked on the DB write. The task throttles writes to at most once per + ~30 seconds per job, keeping last_seen current relative to the 60s + PROCESSING_SERVICE_LAST_SEEN_MAX threshold without hammering the same rows on + every concurrent task-fetch or result-submit request. + + Per-service scoping is not yet possible — marks ALL async services on the + pipeline within this project as live. Once application-token auth lands + (PR #1117) this can be scoped to the individual calling service. """ - import datetime - if not job.pipeline_id: return - job.pipeline.processing_services.async_services().filter(projects=job.project_id).update( - last_seen=datetime.datetime.now(), - last_seen_live=True, - ) + update_pipeline_pull_services_seen.delay(job.pk) class JobFilterSet(filters.FilterSet): From 660360eba2d7f5e93981509aeba7d9e070f31d7f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:44:00 -0700 Subject: [PATCH 2/4] fix(jobs): harden throttled pipeline heartbeat updates (#1260) * fix heartbeat review feedback Agent-Logs-Url: https://github.com/RolnickLab/antenna/sessions/bc1907bb-7118-4133-abab-4c4dd852ecc0 Co-authored-by: mihow <158175+mihow@users.noreply.github.com> * refine heartbeat timestamp handling Agent-Logs-Url: https://github.com/RolnickLab/antenna/sessions/bc1907bb-7118-4133-abab-4c4dd852ecc0 Co-authored-by: mihow <158175+mihow@users.noreply.github.com> * align heartbeat timestamps with local time Agent-Logs-Url: https://github.com/RolnickLab/antenna/sessions/bc1907bb-7118-4133-abab-4c4dd852ecc0 Co-authored-by: mihow <158175+mihow@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: mihow <158175+mihow@users.noreply.github.com> --- ami/jobs/tasks.py | 30 +++++++++-------- ami/jobs/tests/test_jobs.py | 64 ++++++++++++++++++++++++++++++++----- ami/jobs/views.py | 15 ++++++--- 3 files changed, 84 insertions(+), 25 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index c984519ed..2f6109ef0 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -9,6 +9,7 @@ from asgiref.sync import async_to_sync, sync_to_async from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction +from django.db.models import Q from redis.exceptions import RedisError from ami.main.checks.schemas import IntegrityCheckResult @@ -34,18 +35,22 @@ # "nobody's listening" signal. WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5) -# Minimum interval between heartbeat DB writes for a given job. +# Minimum interval between heartbeat DB writes for a given (pipeline, project). # PROCESSING_SERVICE_LAST_SEEN_MAX is 60s; writing at most once per 30s keeps -# last_seen current without hammering the same rows on every concurrent request. +# shared last_seen rows current without hammering them on every concurrent +# request for the same pipeline within a project. HEARTBEAT_THROTTLE_SECONDS = 30 +HEARTBEAT_TASK_EXPIRES_SECONDS = HEARTBEAT_THROTTLE_SECONDS * 2 @celery_app.task( soft_time_limit=10, time_limit=15, + expires=HEARTBEAT_TASK_EXPIRES_SECONDS, + ignore_result=True, # No retries — a missed heartbeat is benign; retrying adds load for no gain. ) -def update_pipeline_pull_services_seen(job_id: int) -> None: +def update_pipeline_pull_services_seen(job_id: int, seen_at_iso: str | None = None) -> None: """ Fire-and-forget heartbeat task: record last_seen/last_seen_live for async (pull-mode) processing services linked to a job's pipeline. @@ -53,9 +58,10 @@ def update_pipeline_pull_services_seen(job_id: int) -> None: Called via .delay() from the tasks and result view endpoints so the HTTP request is never blocked on this DB write. - Throttle: skips the UPDATE if all matching services were seen within - HEARTBEAT_THROTTLE_SECONDS, cutting write rate under concurrent requests by - orders of magnitude while keeping last_seen fresh relative to the 60s + Throttle: skips the UPDATE if every matching service in the shared + (pipeline, project) scope was seen within HEARTBEAT_THROTTLE_SECONDS, + cutting write rate under concurrent requests by orders of magnitude while + keeping last_seen fresh relative to the 60s PROCESSING_SERVICE_LAST_SEEN_MAX threshold. Scope: marks ALL async services on the pipeline within this project as live, @@ -73,17 +79,15 @@ def update_pipeline_pull_services_seen(job_id: int) -> None: if not job.pipeline_id: return - now = datetime.datetime.now() - throttle_cutoff = now - datetime.timedelta(seconds=HEARTBEAT_THROTTLE_SECONDS) + seen_at = datetime.datetime.fromisoformat(seen_at_iso) if seen_at_iso is not None else datetime.datetime.now() + throttle_cutoff = seen_at - datetime.timedelta(seconds=HEARTBEAT_THROTTLE_SECONDS) services_qs = job.pipeline.processing_services.async_services().filter(projects=job.project_id) - - # Cheap read: skip the UPDATE if every matching service was seen recently. - recent_seen = services_qs.values_list("last_seen", flat=True) - if recent_seen and all(ts is not None and ts >= throttle_cutoff for ts in recent_seen): + stale_services_qs = services_qs.filter(Q(last_seen__isnull=True) | Q(last_seen__lt=throttle_cutoff)) + if not stale_services_qs.exists(): return - services_qs.update(last_seen=now, last_seen_live=True) + stale_services_qs.update(last_seen=seen_at, last_seen_live=True) @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 4cb4fb484..e70a35a93 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -1047,7 +1047,7 @@ def setUp(self): def test_tasks_endpoint_dispatches_heartbeat_task(self): """The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" - from unittest.mock import patch + from unittest.mock import ANY, patch job = self.job job.status = JobState.STARTED @@ -1071,11 +1071,11 @@ def test_tasks_endpoint_dispatches_heartbeat_task(self): resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 200) - mock_delay.assert_called_once_with(job.pk) + mock_delay.assert_called_once_with(job.pk, seen_at_iso=ANY) def test_result_endpoint_dispatches_heartbeat_task(self): """The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" - from unittest.mock import MagicMock, patch + from unittest.mock import ANY, MagicMock, patch user = User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True) self.client.force_authenticate(user=user) @@ -1108,7 +1108,37 @@ def test_result_endpoint_dispatches_heartbeat_task(self): resp = self.client.post(result_url, result_data, format="json") self.assertEqual(resp.status_code, 200) - mock_delay.assert_called_once_with(self.job.pk) + mock_delay.assert_called_once_with(self.job.pk, seen_at_iso=ANY) + + def test_tasks_endpoint_tolerates_heartbeat_dispatch_failure(self): + """Heartbeat enqueue errors should not fail the /tasks response.""" + from unittest.mock import patch + + from kombu.exceptions import OperationalError + + job = self.job + job.status = JobState.STARTED + job.save(update_fields=["status"]) + + image = SourceImage.objects.create( + path="hb_tasks_broker.jpg", + public_base_url="http://example.com", + project=self.project, + ) + queue_images_to_nats(job, [image]) + + user = User.objects.create_user(email="hbbroker@example.com", is_superuser=True, is_active=True) + self.client.force_authenticate(user=user) + + with patch( + "ami.jobs.views.update_pipeline_pull_services_seen.delay", + side_effect=OperationalError("broker unavailable"), + ): + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") + + self.assertEqual(resp.status_code, 200) + self.assertEqual(len(resp.json()["tasks"]), 1) def test_heartbeat_task_updates_last_seen_when_stale(self): """update_pipeline_pull_services_seen writes last_seen when the service is stale.""" @@ -1122,11 +1152,12 @@ def test_heartbeat_task_updates_last_seen_when_stale(self): self.service.last_seen_live = False self.service.save(update_fields=["last_seen", "last_seen_live"]) - update_pipeline_pull_services_seen(self.job.pk) + seen_at = datetime.datetime.now() + update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=seen_at.isoformat()) self.service.refresh_from_db() self.assertTrue(self.service.last_seen_live) - self.assertGreater(self.service.last_seen, old_time) + self.assertEqual(self.service.last_seen, seen_at) def test_heartbeat_task_skips_update_when_recent(self): """update_pipeline_pull_services_seen skips the UPDATE when last_seen is within the throttle window.""" @@ -1140,7 +1171,7 @@ def test_heartbeat_task_skips_update_when_recent(self): self.service.last_seen_live = True self.service.save(update_fields=["last_seen", "last_seen_live"]) - update_pipeline_pull_services_seen(self.job.pk) + update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=datetime.datetime.now().isoformat()) self.service.refresh_from_db() # last_seen should not have advanced significantly (throttle skipped the UPDATE) @@ -1175,8 +1206,25 @@ def test_heartbeat_task_no_op_for_job_without_pipeline(self): self.service.last_seen = old_time self.service.save(update_fields=["last_seen"]) - update_pipeline_pull_services_seen(job_no_pipeline.pk) + update_pipeline_pull_services_seen(job_no_pipeline.pk, seen_at_iso=datetime.datetime.now().isoformat()) # Service last_seen should be unchanged because the task returned early self.service.refresh_from_db() self.assertAlmostEqual(self.service.last_seen.timestamp(), old_time.timestamp(), delta=1.0) + + def test_heartbeat_task_does_not_regress_newer_last_seen(self): + """Delayed heartbeats must not overwrite a newer last_seen value.""" + import datetime + + from ami.jobs.tasks import update_pipeline_pull_services_seen + + newer_time = datetime.datetime.now() + delayed_seen_at = newer_time - datetime.timedelta(minutes=1) + self.service.last_seen = newer_time + self.service.last_seen_live = True + self.service.save(update_fields=["last_seen", "last_seen_live"]) + + update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=delayed_seen_at.isoformat()) + + self.service.refresh_from_db() + self.assertEqual(self.service.last_seen, newer_time) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index cf8e9e08d..ad89b94af 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,4 +1,5 @@ import asyncio +import datetime import logging import kombu.exceptions @@ -57,9 +58,10 @@ def _mark_pipeline_pull_services_seen(job: "Job") -> None: Dispatches update_pipeline_pull_services_seen via Celery .delay() so the view is never blocked on the DB write. The task throttles writes to at most once per - ~30 seconds per job, keeping last_seen current relative to the 60s - PROCESSING_SERVICE_LAST_SEEN_MAX threshold without hammering the same rows on - every concurrent task-fetch or result-submit request. + ~30 seconds per pipeline within this project, keeping last_seen current + relative to the 60s PROCESSING_SERVICE_LAST_SEEN_MAX threshold without + hammering the same rows on every concurrent task-fetch or result-submit + request. Per-service scoping is not yet possible — marks ALL async services on the pipeline within this project as live. Once application-token auth lands @@ -67,7 +69,12 @@ def _mark_pipeline_pull_services_seen(job: "Job") -> None: """ if not job.pipeline_id: return - update_pipeline_pull_services_seen.delay(job.pk) + try: + update_pipeline_pull_services_seen.delay(job.pk, seen_at_iso=datetime.datetime.now().isoformat()) + except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc: + msg = f"Failed to enqueue non-critical pipeline heartbeat for job {job.pk}: {exc}" + logger.warning(msg) + job.logger.warning(msg) class JobFilterSet(filters.FilterSet): From fd8e3798f81cd10a0d1ece1410d5368cb2bf026a Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 20 Apr 2026 14:54:58 -0700 Subject: [PATCH 3/4] fix(jobs): gate heartbeat dispatch with Redis cache to cut broker churn Add a view-level cache.add() gate in _mark_pipeline_pull_services_seen keyed on (pipeline_id, project_id) with a HEARTBEAT_THROTTLE_SECONDS timeout. Previously every /tasks and /result request enqueued a Celery task whose sole job was usually to check the throttle and return; now most requests skip the enqueue entirely. The task's own stale-row check remains as a safety net under cache eviction. Key layout is intentionally noted to move to heartbeat:service: once per-service identification lands via application-token auth (PR #1117), so one service's poll cannot suppress another's heartbeat. Adds test_view_gate_suppresses_redundant_dispatches; clears cache in setUp so the gate does not leak state between tests. Co-Authored-By: Claude Sonnet 4.6 --- ami/jobs/tests/test_jobs.py | 18 ++++++++++++++++++ ami/jobs/views.py | 27 +++++++++++++++++---------- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index e70a35a93..1dcbd959c 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -1026,6 +1026,12 @@ class TestPipelineHeartbeatTask(APITestCase): """ def setUp(self): + from django.core.cache import cache + + # Cache-based gate in _mark_pipeline_pull_services_seen would otherwise + # carry over between tests and suppress the .delay() we want to assert. + cache.clear() + self.project = Project.objects.create(name="Heartbeat Test Project") self.pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline") self.pipeline.projects.add(self.project) @@ -1228,3 +1234,15 @@ def test_heartbeat_task_does_not_regress_newer_last_seen(self): self.service.refresh_from_db() self.assertEqual(self.service.last_seen, newer_time) + + def test_view_gate_suppresses_redundant_dispatches(self): + """Rapid repeated calls to _mark_pipeline_pull_services_seen should only enqueue once per window.""" + from unittest.mock import patch + + from ami.jobs.views import _mark_pipeline_pull_services_seen + + with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay: + for _ in range(5): + _mark_pipeline_pull_services_seen(self.job) + + self.assertEqual(mock_delay.call_count, 1) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index ad89b94af..6c9f382bc 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -5,6 +5,7 @@ import kombu.exceptions import nats.errors from asgiref.sync import async_to_sync +from django.core.cache import cache from django.db.models import Q from django.db.models.query import QuerySet from django.forms import IntegerField @@ -25,7 +26,7 @@ MLJobTasksRequestSerializer, MLJobTasksResponseSerializer, ) -from ami.jobs.tasks import process_nats_pipeline_result, update_pipeline_pull_services_seen +from ami.jobs.tasks import HEARTBEAT_THROTTLE_SECONDS, process_nats_pipeline_result, update_pipeline_pull_services_seen from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param @@ -57,18 +58,24 @@ def _mark_pipeline_pull_services_seen(job: "Job") -> None: linked to the job's pipeline. Dispatches update_pipeline_pull_services_seen via Celery .delay() so the view - is never blocked on the DB write. The task throttles writes to at most once per - ~30 seconds per pipeline within this project, keeping last_seen current - relative to the 60s PROCESSING_SERVICE_LAST_SEEN_MAX threshold without - hammering the same rows on every concurrent task-fetch or result-submit - request. - - Per-service scoping is not yet possible — marks ALL async services on the - pipeline within this project as live. Once application-token auth lands - (PR #1117) this can be scoped to the individual calling service. + is never blocked on the DB write. A view-level Redis cache gate skips the + .delay() entirely when a heartbeat for the same (pipeline, project) has been + enqueued within HEARTBEAT_THROTTLE_SECONDS — so under concurrent polling we + avoid broker + task churn, not just the DB write. The task itself also + re-checks staleness before writing (belt + suspenders, and safe under cache + eviction). + + Cache key scope: currently `heartbeat:pipeline::project:` + because we cannot yet identify the specific calling service. Once + application-token auth lands (PR #1117), the key should become + `heartbeat:service:` so each service gets its own throttle + window and one service's poll does not suppress another's heartbeat. """ if not job.pipeline_id: return + cache_key = f"heartbeat:pipeline:{job.pipeline_id}:project:{job.project_id}" + if not cache.add(cache_key, 1, timeout=HEARTBEAT_THROTTLE_SECONDS): + return try: update_pipeline_pull_services_seen.delay(job.pk, seen_at_iso=datetime.datetime.now().isoformat()) except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc: From e2f982edeffa87e2ab98a5a0c7f1aa617d720d11 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 20 Apr 2026 15:11:39 -0700 Subject: [PATCH 4/4] refactor(jobs): simplify heartbeat task to rely on view-level gate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The view's Redis cache gate (HEARTBEAT_THROTTLE_SECONDS, added in fd8e3798) is the load-bearing throttle — it skips the whole .delay() when a recent heartbeat has already fired for the same (pipeline, project). With dispatch already gated, the task doesn't need its own throttle/staleness machinery. - Drop seen_at_iso param, Q-filter, .exists() preflight, expires=. - Task is now just .update(last_seen=now(), last_seen_live=True). - Drop 5 task-body tests (staleness, skip-when-recent, no-op-for- missing-job, no-op-for-no-pipeline, regression-guard) that exercised logic we removed. Keep 4: two endpoint dispatch smokes, broker- failure tolerance, and the view-gate suppression test. Crash-safety is unchanged: the write is still off the gunicorn request path, which was the SIGSEGV mitigation. --- ami/jobs/tasks.py | 36 +++++--------- ami/jobs/tests/test_jobs.py | 97 ++----------------------------------- ami/jobs/views.py | 14 ++---- 3 files changed, 20 insertions(+), 127 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 2f6109ef0..629c95828 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -9,7 +9,6 @@ from asgiref.sync import async_to_sync, sync_to_async from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction -from django.db.models import Q from redis.exceptions import RedisError from ami.main.checks.schemas import IntegrityCheckResult @@ -35,34 +34,26 @@ # "nobody's listening" signal. WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5) -# Minimum interval between heartbeat DB writes for a given (pipeline, project). -# PROCESSING_SERVICE_LAST_SEEN_MAX is 60s; writing at most once per 30s keeps -# shared last_seen rows current without hammering them on every concurrent -# request for the same pipeline within a project. +# Minimum interval between heartbeat dispatches for a given (pipeline, project). +# The view-level Redis cache gate uses this window to skip .delay() under +# concurrent polling; the task itself does no throttling. HEARTBEAT_THROTTLE_SECONDS = 30 -HEARTBEAT_TASK_EXPIRES_SECONDS = HEARTBEAT_THROTTLE_SECONDS * 2 @celery_app.task( soft_time_limit=10, time_limit=15, - expires=HEARTBEAT_TASK_EXPIRES_SECONDS, ignore_result=True, # No retries — a missed heartbeat is benign; retrying adds load for no gain. ) -def update_pipeline_pull_services_seen(job_id: int, seen_at_iso: str | None = None) -> None: +def update_pipeline_pull_services_seen(job_id: int) -> None: """ Fire-and-forget heartbeat task: record last_seen/last_seen_live for async (pull-mode) processing services linked to a job's pipeline. - Called via .delay() from the tasks and result view endpoints so the HTTP - request is never blocked on this DB write. - - Throttle: skips the UPDATE if every matching service in the shared - (pipeline, project) scope was seen within HEARTBEAT_THROTTLE_SECONDS, - cutting write rate under concurrent requests by orders of magnitude while - keeping last_seen fresh relative to the 60s - PROCESSING_SERVICE_LAST_SEEN_MAX threshold. + Throttling lives in the view (Redis cache gate over HEARTBEAT_THROTTLE_SECONDS), + so this task is dispatched at most once per (pipeline, project) per window + and can just write. Scope: marks ALL async services on the pipeline within this project as live, not just the specific service that made the request. Once application-token @@ -79,15 +70,10 @@ def update_pipeline_pull_services_seen(job_id: int, seen_at_iso: str | None = No if not job.pipeline_id: return - seen_at = datetime.datetime.fromisoformat(seen_at_iso) if seen_at_iso is not None else datetime.datetime.now() - throttle_cutoff = seen_at - datetime.timedelta(seconds=HEARTBEAT_THROTTLE_SECONDS) - - services_qs = job.pipeline.processing_services.async_services().filter(projects=job.project_id) - stale_services_qs = services_qs.filter(Q(last_seen__isnull=True) | Q(last_seen__lt=throttle_cutoff)) - if not stale_services_qs.exists(): - return - - stale_services_qs.update(last_seen=seen_at, last_seen_live=True) + job.pipeline.processing_services.async_services().filter(projects=job.project_id).update( + last_seen=datetime.datetime.now(), + last_seen_live=True, + ) @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 1dcbd959c..90d1f6baa 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -1053,7 +1053,7 @@ def setUp(self): def test_tasks_endpoint_dispatches_heartbeat_task(self): """The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" - from unittest.mock import ANY, patch + from unittest.mock import patch job = self.job job.status = JobState.STARTED @@ -1077,11 +1077,11 @@ def test_tasks_endpoint_dispatches_heartbeat_task(self): resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 200) - mock_delay.assert_called_once_with(job.pk, seen_at_iso=ANY) + mock_delay.assert_called_once_with(job.pk) def test_result_endpoint_dispatches_heartbeat_task(self): """The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" - from unittest.mock import ANY, MagicMock, patch + from unittest.mock import MagicMock, patch user = User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True) self.client.force_authenticate(user=user) @@ -1114,7 +1114,7 @@ def test_result_endpoint_dispatches_heartbeat_task(self): resp = self.client.post(result_url, result_data, format="json") self.assertEqual(resp.status_code, 200) - mock_delay.assert_called_once_with(self.job.pk, seen_at_iso=ANY) + mock_delay.assert_called_once_with(self.job.pk) def test_tasks_endpoint_tolerates_heartbeat_dispatch_failure(self): """Heartbeat enqueue errors should not fail the /tasks response.""" @@ -1146,95 +1146,6 @@ def test_tasks_endpoint_tolerates_heartbeat_dispatch_failure(self): self.assertEqual(resp.status_code, 200) self.assertEqual(len(resp.json()["tasks"]), 1) - def test_heartbeat_task_updates_last_seen_when_stale(self): - """update_pipeline_pull_services_seen writes last_seen when the service is stale.""" - import datetime - - from ami.jobs.tasks import update_pipeline_pull_services_seen - - # Set last_seen to well past the throttle window - old_time = datetime.datetime.now() - datetime.timedelta(minutes=5) - self.service.last_seen = old_time - self.service.last_seen_live = False - self.service.save(update_fields=["last_seen", "last_seen_live"]) - - seen_at = datetime.datetime.now() - update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=seen_at.isoformat()) - - self.service.refresh_from_db() - self.assertTrue(self.service.last_seen_live) - self.assertEqual(self.service.last_seen, seen_at) - - def test_heartbeat_task_skips_update_when_recent(self): - """update_pipeline_pull_services_seen skips the UPDATE when last_seen is within the throttle window.""" - import datetime - - from ami.jobs.tasks import update_pipeline_pull_services_seen - - # Set last_seen to just now — well inside the 30s throttle window - recent_time = datetime.datetime.now() - datetime.timedelta(seconds=5) - self.service.last_seen = recent_time - self.service.last_seen_live = True - self.service.save(update_fields=["last_seen", "last_seen_live"]) - - update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=datetime.datetime.now().isoformat()) - - self.service.refresh_from_db() - # last_seen should not have advanced significantly (throttle skipped the UPDATE) - self.assertAlmostEqual( - self.service.last_seen.timestamp(), - recent_time.timestamp(), - delta=1.0, - ) - - def test_heartbeat_task_no_op_for_missing_job(self): - """update_pipeline_pull_services_seen silently returns when job_id does not exist.""" - from ami.jobs.tasks import update_pipeline_pull_services_seen - - # Should not raise - update_pipeline_pull_services_seen(job_id=999999) - - def test_heartbeat_task_no_op_for_job_without_pipeline(self): - """update_pipeline_pull_services_seen returns early when job has no pipeline.""" - import datetime - - from ami.jobs.tasks import update_pipeline_pull_services_seen - - job_no_pipeline = Job.objects.create( - job_type_key=MLJob.key, - project=self.project, - name="No-pipeline job", - source_image_collection=self.collection, - dispatch_mode=JobDispatchMode.ASYNC_API, - ) - - old_time = datetime.datetime.now() - datetime.timedelta(minutes=10) - self.service.last_seen = old_time - self.service.save(update_fields=["last_seen"]) - - update_pipeline_pull_services_seen(job_no_pipeline.pk, seen_at_iso=datetime.datetime.now().isoformat()) - - # Service last_seen should be unchanged because the task returned early - self.service.refresh_from_db() - self.assertAlmostEqual(self.service.last_seen.timestamp(), old_time.timestamp(), delta=1.0) - - def test_heartbeat_task_does_not_regress_newer_last_seen(self): - """Delayed heartbeats must not overwrite a newer last_seen value.""" - import datetime - - from ami.jobs.tasks import update_pipeline_pull_services_seen - - newer_time = datetime.datetime.now() - delayed_seen_at = newer_time - datetime.timedelta(minutes=1) - self.service.last_seen = newer_time - self.service.last_seen_live = True - self.service.save(update_fields=["last_seen", "last_seen_live"]) - - update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=delayed_seen_at.isoformat()) - - self.service.refresh_from_db() - self.assertEqual(self.service.last_seen, newer_time) - def test_view_gate_suppresses_redundant_dispatches(self): """Rapid repeated calls to _mark_pipeline_pull_services_seen should only enqueue once per window.""" from unittest.mock import patch diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 6c9f382bc..ec9d64481 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,5 +1,4 @@ import asyncio -import datetime import logging import kombu.exceptions @@ -57,13 +56,10 @@ def _mark_pipeline_pull_services_seen(job: "Job") -> None: Enqueue a fire-and-forget heartbeat for async (pull-mode) processing services linked to the job's pipeline. - Dispatches update_pipeline_pull_services_seen via Celery .delay() so the view - is never blocked on the DB write. A view-level Redis cache gate skips the - .delay() entirely when a heartbeat for the same (pipeline, project) has been - enqueued within HEARTBEAT_THROTTLE_SECONDS — so under concurrent polling we - avoid broker + task churn, not just the DB write. The task itself also - re-checks staleness before writing (belt + suspenders, and safe under cache - eviction). + A Redis cache gate skips the dispatch when a heartbeat for the same + (pipeline, project) has already fired within HEARTBEAT_THROTTLE_SECONDS, + so under concurrent polling we avoid broker + task churn. The Celery task + keeps the DB write off the HTTP request path. Cache key scope: currently `heartbeat:pipeline::project:` because we cannot yet identify the specific calling service. Once @@ -77,7 +73,7 @@ def _mark_pipeline_pull_services_seen(job: "Job") -> None: if not cache.add(cache_key, 1, timeout=HEARTBEAT_THROTTLE_SECONDS): return try: - update_pipeline_pull_services_seen.delay(job.pk, seen_at_iso=datetime.datetime.now().isoformat()) + update_pipeline_pull_services_seen.delay(job.pk) except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc: msg = f"Failed to enqueue non-critical pipeline heartbeat for job {job.pk}: {exc}" logger.warning(msg)