From b60eab0ff435e32b0264b182b4022864cdce71f0 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 16 Jan 2026 11:25:40 -0800 Subject: [PATCH 01/30] merge --- requirements/base.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..ed40ea5f7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -52,6 +52,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing From bc908aa46b2aafc12e89ee4956fd22cc7a9be404 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 20 Feb 2026 18:23:57 -0800 Subject: [PATCH 02/30] fix: PSv2 follow-up fixes from integration tests (#1135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: prevent NATS connection flooding and stale job task fetching - Add connect_timeout=5, allow_reconnect=False to NATS connections to prevent leaked reconnection loops from blocking Django's event loop - Guard /tasks endpoint against terminal-status jobs (return empty tasks instead of attempting NATS reserve) - IncompleteJobFilter now excludes jobs by top-level status in addition to progress JSON stages - Add stale worker cleanup to integration test script Found during PSv2 integration testing where stale ADC workers with default DataLoader parallelism overwhelmed the single uvicorn worker thread by flooding /tasks with concurrent NATS reserve requests. Co-Authored-By: Claude * docs: PSv2 integration test session notes and NATS flooding findings Session notes from 2026-02-16 integration test including root cause analysis of stale worker task competition and NATS connection issues. Findings doc tracks applied fixes and remaining TODOs with priorities. Co-Authored-By: Claude * docs: update session notes with successful test run #3 PSv2 integration test passed end-to-end (job 1380, 20/20 images). Identified ack_wait=300s as cause of ~5min idle time when GPU processes race for NATS tasks. Co-Authored-By: Claude * fix: batch NATS task fetch to prevent HTTP timeouts Replace N×1 reserve_task() calls with single reserve_tasks() batch fetch. The previous implementation created a new pull subscription per message (320 NATS round trips for batch=64), causing the /tasks endpoint to exceed HTTP client timeouts. The new approach uses one psub.fetch() call for the entire batch. Co-Authored-By: Claude * docs: add next session prompt * feat: add pipeline__slug__in filter for multi-pipeline job queries Workers that handle multiple pipelines can now fetch jobs for all of them in a single request: ?pipeline__slug__in=slug1,slug2 Co-Authored-By: Claude * chore: remove local-only docs and scripts from branch These files are session notes, planning docs, and test scripts that should stay local rather than be part of the PR. Co-Authored-By: Claude * feat: set job dispatch_mode at creation time based on project feature flags ML jobs with a pipeline now get dispatch_mode set during setup() instead of waiting until run() is called by the Celery worker. This lets the UI show the correct mode immediately after job creation. Co-Authored-By: Claude * fix: add timeouts to all JetStream operations and restore reconnect policy Add NATS_JETSTREAM_TIMEOUT (10s) to all JetStream metadata operations via asyncio.wait_for() so a hung NATS connection fails fast instead of blocking the caller's thread indefinitely. Also restore the intended reconnect policy (2 attempts, 1s wait) that was lost in a prior force push. Co-Authored-By: Claude * fix: propagate NATS timeouts as 503 instead of swallowing them asyncio.TimeoutError from _ensure_stream() and _ensure_consumer() was caught by the broad `except Exception` in reserve_tasks(), silently returning [] and making NATS outages indistinguishable from empty queues. Workers would then poll immediately, recreating the flooding problem. - Add explicit `except asyncio.TimeoutError: raise` in reserve_tasks() - Catch TimeoutError and OSError in the /tasks view, return 503 - Restore allow_reconnect=False (fail-fast on connection issues) - Add return type annotation to get_connection() Co-Authored-By: Claude * fix: address review comments (log level, fetch timeout, docstring) - Downgrade reserve_tasks log to DEBUG when zero tasks reserved (avoid log spam from frequent polling) - Pass timeout=0.5 from /tasks endpoint to avoid blocking the worker for 5s on empty queues - Fix docstring examples using string 'job123' for int-typed job_id Co-Authored-By: Claude * fix: catch nats.errors.Error in /tasks endpoint for proper 503 responses NoServersError, ConnectionClosedError, and other NATS exceptions inherit from nats.errors.Error (not OSError), so they escaped the handler and returned 500 instead of 503. Co-Authored-By: Claude --------- Co-authored-by: Claude --- ami/jobs/models.py | 15 +- ami/jobs/tests.py | 69 ++++++++- ami/jobs/views.py | 30 ++-- ami/ml/orchestration/nats_queue.py | 137 +++++++++++------- ami/ml/orchestration/tests/test_nats_queue.py | 63 +++++--- 5 files changed, 220 insertions(+), 94 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index b4df41a04..be797dd4f 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -461,9 +461,7 @@ def run(cls, job: "Job"): # End image collection stage job.save() - if job.project.feature_flags.async_pipeline_workers: - job.dispatch_mode = JobDispatchMode.ASYNC_API - job.save(update_fields=["dispatch_mode"]) + if job.dispatch_mode == JobDispatchMode.ASYNC_API: queued = queue_images_to_nats(job, images) if not queued: job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) @@ -473,8 +471,6 @@ def run(cls, job: "Job"): job.save() return else: - job.dispatch_mode = JobDispatchMode.SYNC_API - job.save(update_fields=["dispatch_mode"]) cls.process_images(job, images) @classmethod @@ -919,6 +915,15 @@ def setup(self, save=True): self.progress.add_stage_param(delay_stage.key, "Mood", "😴") if self.pipeline: + # Set dispatch mode based on project feature flags at creation time + # so the UI can show the correct mode before the job runs. + # Only override if still at the default (INTERNAL), to allow explicit overrides. + if self.dispatch_mode == JobDispatchMode.INTERNAL: + if self.project and self.project.feature_flags.async_pipeline_workers: + self.dispatch_mode = JobDispatchMode.ASYNC_API + else: + self.dispatch_mode = JobDispatchMode.SYNC_API + collect_stage = self.progress.add_stage("Collect") self.progress.add_stage_param(collect_stage.key, "Total Images", "") diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 7902faeb1..033a08b5c 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -384,6 +384,36 @@ def test_filter_by_pipeline_slug(self): self.assertEqual(data["count"], 1) self.assertEqual(data["results"][0]["id"], job_with_pipeline.pk) + def test_filter_by_pipeline_slug_in(self): + """Test filtering jobs by pipeline__slug__in (multiple slugs).""" + pipeline_a = self._create_pipeline("Pipeline A", "pipeline-a") + pipeline_b = Pipeline.objects.create(name="Pipeline B", slug="pipeline-b", description="B") + pipeline_b.projects.add(self.project) + pipeline_c = Pipeline.objects.create(name="Pipeline C", slug="pipeline-c", description="C") + pipeline_c.projects.add(self.project) + + job_a = self._create_ml_job("Job A", pipeline_a) + job_b = self._create_ml_job("Job B", pipeline_b) + job_c = self._create_ml_job("Job C", pipeline_c) + + self.client.force_authenticate(user=self.user) + + # Filter for two of the three pipelines + jobs_list_url = reverse_with_params( + "api:job-list", + params={"project_id": self.project.pk, "pipeline__slug__in": "pipeline-a,pipeline-b"}, + ) + resp = self.client.get(jobs_list_url) + + self.assertEqual(resp.status_code, 200) + data = resp.json() + returned_ids = {job["id"] for job in data["results"]} + self.assertIn(job_a.pk, returned_ids) + self.assertIn(job_b.pk, returned_ids) + self.assertNotIn(job_c.pk, returned_ids) + # Original setUp job (no pipeline) should also be excluded + self.assertNotIn(self.job.pk, returned_ids) + def test_search_jobs(self): """Test searching jobs by name and pipeline name.""" pipeline = self._create_pipeline("SearchablePipeline", "searchable-pipeline") @@ -571,13 +601,11 @@ def test_dispatch_mode_filtering(self): dispatch_mode=JobDispatchMode.ASYNC_API, ) - # Create a job with default dispatch_mode (should be "internal") + # Create a non-ML job without a pipeline (dispatch_mode stays "internal") internal_job = Job.objects.create( - job_type_key=MLJob.key, + job_type_key="data_storage_sync", project=self.project, name="Internal Job", - pipeline=self.pipeline, - source_image_collection=self.source_image_collection, ) self.client.force_authenticate(user=self.user) @@ -614,6 +642,39 @@ def test_dispatch_mode_filtering(self): expected_ids = {sync_job.pk, async_job.pk, internal_job.pk} self.assertEqual(returned_ids, expected_ids) + def test_ml_job_dispatch_mode_set_on_creation(self): + """Test that ML jobs get dispatch_mode set based on project feature flags at creation time.""" + # Without async flag, ML job should default to sync_api + sync_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Auto Sync Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + self.assertEqual(sync_job.dispatch_mode, JobDispatchMode.SYNC_API) + + # Enable async flag on project + self.project.feature_flags.async_pipeline_workers = True + self.project.save() + + async_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Auto Async Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + self.assertEqual(async_job.dispatch_mode, JobDispatchMode.ASYNC_API) + + # Non-pipeline job should stay internal regardless of feature flag + internal_job = Job.objects.create( + job_type_key="data_storage_sync", + project=self.project, + name="Internal Job", + ) + self.assertEqual(internal_job.dispatch_mode, JobDispatchMode.INTERNAL) + def test_tasks_endpoint_rejects_non_async_jobs(self): """Test that /tasks endpoint returns 400 for non-async_api jobs.""" from ami.base.serializers import reverse_with_params diff --git a/ami/jobs/views.py b/ami/jobs/views.py index dd8da01b2..ddc1e57a7 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,5 +1,7 @@ +import asyncio import logging +import nats.errors import pydantic from asgiref.sync import async_to_sync from django.db.models import Q @@ -32,6 +34,7 @@ class JobFilterSet(filters.FilterSet): """Custom filterset to enable pipeline name filtering.""" pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact") + pipeline__slug__in = filters.BaseInFilter(field_name="pipeline__slug", lookup_expr="in") class Meta: model = Job @@ -55,11 +58,12 @@ def filter_queryset(self, request, queryset, view): incomplete_only = url_boolean_param(request, "incomplete_only", default=False) # Filter to incomplete jobs if requested (checks "results" stage status) if incomplete_only: - # Create filters for each final state to exclude + # Exclude jobs with a terminal top-level status + queryset = queryset.exclude(status__in=JobState.final_states()) + + # Also exclude jobs where the "results" stage has a final state status final_states = JobState.final_states() exclude_conditions = Q() - - # Exclude jobs where the "results" stage has a final state status for state in final_states: # JSON path query to check if results stage status is in final states # @TODO move to a QuerySet method on Job model if/when this needs to be reused elsewhere @@ -233,6 +237,10 @@ def tasks(self, request, pk=None): if job.dispatch_mode != JobDispatchMode.ASYNC_API: raise ValidationError("Only async_api jobs have fetchable tasks") + # Don't fetch tasks from completed/failed/revoked jobs + if job.status in JobState.final_states(): + return Response({"tasks": []}) + # Validate that the job has a pipeline if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") @@ -241,16 +249,14 @@ def tasks(self, request, pk=None): from ami.ml.orchestration.nats_queue import TaskQueueManager async def get_tasks(): - tasks = [] async with TaskQueueManager() as manager: - for _ in range(batch): - task = await manager.reserve_task(job.pk, timeout=0.1) - if task: - tasks.append(task.dict()) - return tasks - - # Use async_to_sync to properly handle the async call - tasks = async_to_sync(get_tasks)() + return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)] + + try: + tasks = async_to_sync(get_tasks)() + except (asyncio.TimeoutError, OSError, nats.errors.Error) as e: + logger.warning("NATS unavailable while fetching tasks for job %s: %s", job.pk, e) + return Response({"error": "Task queue temporarily unavailable"}, status=503) return Response({"tasks": tasks}) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index fa7188627..65b6f6f72 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -10,6 +10,7 @@ support the visibility timeout semantics we want or a disconnected mode of pulling and ACKing tasks. """ +import asyncio import json import logging @@ -22,9 +23,21 @@ logger = logging.getLogger(__name__) - -async def get_connection(nats_url: str): - nc = await nats.connect(nats_url) +# Timeout for individual JetStream metadata operations (create/check stream and consumer). +# These are lightweight NATS server operations that complete in milliseconds under normal +# conditions. stream_info() and add_stream() don't accept a native timeout parameter, so +# we use asyncio.wait_for() uniformly for all operations. Without these timeouts, a hung +# NATS connection blocks the caller's thread indefinitely — and when that caller is a +# Django worker (via async_to_sync), it makes the entire server unresponsive. +NATS_JETSTREAM_TIMEOUT = 10 # seconds + + +async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]: + nc = await nats.connect( + nats_url, + connect_timeout=5, + allow_reconnect=False, + ) js = nc.jetstream() return nc, js @@ -38,9 +51,9 @@ class TaskQueueManager: Use as an async context manager: async with TaskQueueManager() as manager: - await manager.publish_task('job123', {'data': 'value'}) - task = await manager.reserve_task('job123') - await manager.acknowledge_task(task['reply_subject']) + await manager.publish_task(123, {'data': 'value'}) + tasks = await manager.reserve_tasks(123, count=64) + await manager.acknowledge_task(tasks[0].reply_subject) """ def __init__(self, nats_url: str | None = None): @@ -83,15 +96,20 @@ async def _ensure_stream(self, job_id: int): subject = self._get_subject(job_id) try: - await self.js.stream_info(stream_name) + await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT) logger.debug(f"Stream {stream_name} already exists") + except asyncio.TimeoutError: + raise # NATS unreachable — let caller handle it rather than creating a stream blindly except Exception as e: logger.warning(f"Stream {stream_name} does not exist: {e}") # Stream doesn't exist, create it - await self.js.add_stream( - name=stream_name, - subjects=[subject], - max_age=86400, # 24 hours retention + await asyncio.wait_for( + self.js.add_stream( + name=stream_name, + subjects=[subject], + max_age=86400, # 24 hours retention + ), + timeout=NATS_JETSTREAM_TIMEOUT, ) logger.info(f"Created stream {stream_name}") @@ -105,21 +123,29 @@ async def _ensure_consumer(self, job_id: int): subject = self._get_subject(job_id) try: - info = await self.js.consumer_info(stream_name, consumer_name) + info = await asyncio.wait_for( + self.js.consumer_info(stream_name, consumer_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) logger.debug(f"Consumer {consumer_name} already exists: {info}") + except asyncio.TimeoutError: + raise # NATS unreachable — let caller handle it except Exception: # Consumer doesn't exist, create it - await self.js.add_consumer( - stream=stream_name, - config=ConsumerConfig( - durable_name=consumer_name, - ack_policy=AckPolicy.EXPLICIT, - ack_wait=TASK_TTR, # Visibility timeout (TTR) - max_deliver=5, # Max retry attempts - deliver_policy=DeliverPolicy.ALL, - max_ack_pending=100, # Max unacked messages - filter_subject=subject, + await asyncio.wait_for( + self.js.add_consumer( + stream=stream_name, + config=ConsumerConfig( + durable_name=consumer_name, + ack_policy=AckPolicy.EXPLICIT, + ack_wait=TASK_TTR, # Visibility timeout (TTR) + max_deliver=5, # Max retry attempts + deliver_policy=DeliverPolicy.ALL, + max_ack_pending=100, # Max unacked messages + filter_subject=subject, + ), ), + timeout=NATS_JETSTREAM_TIMEOUT, ) logger.info(f"Created consumer {consumer_name}") @@ -147,7 +173,7 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: task_data = json.dumps(data.dict()) # Publish to JetStream - ack = await self.js.publish(subject, task_data.encode()) + ack = await self.js.publish(subject, task_data.encode(), timeout=NATS_JETSTREAM_TIMEOUT) logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") return True @@ -156,62 +182,57 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") return False - async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: + async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> list[PipelineProcessingTask]: """ - Reserve a task from the specified stream. + Reserve up to `count` tasks from the specified stream in a single NATS fetch. Args: job_id: The job ID (integer primary key) to pull tasks from - timeout: Timeout in seconds for reservation (default: 5 seconds) + count: Maximum number of tasks to reserve + timeout: Timeout in seconds waiting for messages (default: 5 seconds) Returns: - PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available + List of PipelineProcessingTask objects with reply_subject set for acknowledgment. + May return fewer than `count` if the queue has fewer messages available. """ if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") - if timeout is None: - timeout = 5 - try: - # Ensure stream and consumer exist await self._ensure_stream(job_id) await self._ensure_consumer(job_id) consumer_name = self._get_consumer_name(job_id) subject = self._get_subject(job_id) - # Create ephemeral subscription for this pull psub = await self.js.pull_subscribe(subject, consumer_name) try: - # Fetch a single message - msgs = await psub.fetch(1, timeout=timeout) - - if msgs: - msg = msgs[0] - task_data = json.loads(msg.data.decode()) - metadata = msg.metadata - - # Parse the task data into PipelineProcessingTask - task = PipelineProcessingTask(**task_data) - # Set the reply_subject for acknowledgment - task.reply_subject = msg.reply - - logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") - return task - + msgs = await psub.fetch(count, timeout=timeout) except nats.errors.TimeoutError: - # No messages available logger.debug(f"No tasks available in stream for job '{job_id}'") - return None + return [] finally: - # Always unsubscribe await psub.unsubscribe() + tasks = [] + for msg in msgs: + task_data = json.loads(msg.data.decode()) + task = PipelineProcessingTask(**task_data) + task.reply_subject = msg.reply + tasks.append(task) + + if tasks: + logger.info(f"Reserved {len(tasks)} tasks from stream for job '{job_id}'") + else: + logger.debug(f"No tasks reserved from stream for job '{job_id}'") + return tasks + + except asyncio.TimeoutError: + raise # NATS unreachable — propagate so the view can return an appropriate error except Exception as e: - logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") - return None + logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}") + return [] async def acknowledge_task(self, reply_subject: str) -> bool: """ @@ -251,7 +272,10 @@ async def delete_consumer(self, job_id: int) -> bool: stream_name = self._get_stream_name(job_id) consumer_name = self._get_consumer_name(job_id) - await self.js.delete_consumer(stream_name, consumer_name) + await asyncio.wait_for( + self.js.delete_consumer(stream_name, consumer_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") return True except Exception as e: @@ -274,7 +298,10 @@ async def delete_stream(self, job_id: int) -> bool: try: stream_name = self._get_stream_name(job_id) - await self.js.delete_stream(stream_name) + await asyncio.wait_for( + self.js.delete_stream(stream_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) logger.info(f"Deleted stream {stream_name} for job '{job_id}'") return True except Exception as e: diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 0cd2c3bef..a7bd91b68 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -62,47 +62,74 @@ async def test_publish_task_creates_stream_and_consumer(self): self.assertIn("job_456", str(js.add_stream.call_args)) js.add_consumer.assert_called_once() - async def test_reserve_task_success(self): - """Test successful task reservation.""" + async def test_reserve_tasks_success(self): + """Test successful batch task reservation.""" nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() - # Mock message with task data - mock_msg = MagicMock() - mock_msg.data = sample_task.json().encode() - mock_msg.reply = "reply.subject.123" - mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) + # Mock messages with task data + mock_msg1 = MagicMock() + mock_msg1.data = sample_task.json().encode() + mock_msg1.reply = "reply.subject.1" + + mock_msg2 = MagicMock() + mock_msg2.data = sample_task.json().encode() + mock_msg2.reply = "reply.subject.2" mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.fetch = AsyncMock(return_value=[mock_msg1, mock_msg2]) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + tasks = await manager.reserve_tasks(123, count=5) - self.assertIsNotNone(task) - self.assertEqual(task.id, sample_task.id) - self.assertEqual(task.reply_subject, "reply.subject.123") + self.assertEqual(len(tasks), 2) + self.assertEqual(tasks[0].id, sample_task.id) + self.assertEqual(tasks[0].reply_subject, "reply.subject.1") + self.assertEqual(tasks[1].reply_subject, "reply.subject.2") + mock_psub.fetch.assert_called_once_with(5, timeout=5) mock_psub.unsubscribe.assert_called_once() - async def test_reserve_task_no_messages(self): - """Test reserve_task when no messages are available.""" + async def test_reserve_tasks_no_messages(self): + """Test reserve_tasks when no messages are available (timeout).""" nc, js = self._create_mock_nats_connection() + import nats.errors mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[]) + mock_psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + tasks = await manager.reserve_tasks(123, count=5) - self.assertIsNone(task) + self.assertEqual(tasks, []) mock_psub.unsubscribe.assert_called_once() + async def test_reserve_tasks_single(self): + """Test reserving a single task.""" + nc, js = self._create_mock_nats_connection() + sample_task = self._create_sample_task() + + mock_msg = MagicMock() + mock_msg.data = sample_task.json().encode() + mock_msg.reply = "reply.subject.123" + + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + tasks = await manager.reserve_tasks(123, count=1) + + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].reply_subject, "reply.subject.123") + async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" nc, js = self._create_mock_nats_connection() @@ -144,7 +171,7 @@ async def test_operations_without_connection_raise_error(self): await manager.publish_task(123, sample_task) with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.reserve_task(123) + await manager.reserve_tasks(123, count=1) with self.assertRaisesRegex(RuntimeError, "Connection is not open"): await manager.delete_stream(123) From 4c3802aa77648c0f2a59d2d3e944e2827c2dc742 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 20 Feb 2026 19:29:34 -0800 Subject: [PATCH 03/30] PSv2: Improve task fetching & web worker concurrency configuration (#1142) * feat: configurable NATS tuning and gunicorn worker management Rebase onto main after #1135 merge. Keep only the additions unique to this branch: - Make TASK_TTR configurable via NATS_TASK_TTR Django setting (default 30s) - Make max_ack_pending configurable via NATS_MAX_ACK_PENDING setting (default 100) - Local dev: switch to gunicorn+UvicornWorker by default for production parity, with USE_UVICORN=1 escape hatch for raw uvicorn - Production: auto-detect WEB_CONCURRENCY from CPU cores (capped at 8) when not explicitly set in the environment Co-Authored-By: Claude * fix: address PR review comments - Fix max_ack_pending falsy-zero guard (use `is not None` instead of `or`) - Update TaskQueueManager docstring with Args section - Simplify production WEB_CONCURRENCY fallback (just use nproc) Co-Authored-By: Claude --------- Co-authored-by: Michael Bunsen Co-authored-by: Claude --- ami/ml/orchestration/nats_queue.py | 14 +++++++++++--- compose/local/django/start | 24 ++++++++++++++++++++---- compose/production/django/start | 8 ++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 65b6f6f72..489a0dc66 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -42,13 +42,18 @@ async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]: return nc, js -TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds +TASK_TTR = getattr(settings, "NATS_TASK_TTR", 30) # Visibility timeout in seconds (configurable) class TaskQueueManager: """ Manager for NATS JetStream task queue operations. + Args: + nats_url: NATS server URL. Falls back to settings.NATS_URL, then "nats://nats:4222". + max_ack_pending: Max unacknowledged messages per consumer. Falls back to + settings.NATS_MAX_ACK_PENDING, then 100. + Use as an async context manager: async with TaskQueueManager() as manager: await manager.publish_task(123, {'data': 'value'}) @@ -56,8 +61,11 @@ class TaskQueueManager: await manager.acknowledge_task(tasks[0].reply_subject) """ - def __init__(self, nats_url: str | None = None): + def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = None): self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") + self.max_ack_pending = ( + max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 100) + ) self.nc: nats.NATS | None = None self.js: JetStreamContext | None = None @@ -141,7 +149,7 @@ async def _ensure_consumer(self, job_id: int): ack_wait=TASK_TTR, # Visibility timeout (TTR) max_deliver=5, # Max retry attempts deliver_policy=DeliverPolicy.ALL, - max_ack_pending=100, # Max unacked messages + max_ack_pending=self.max_ack_pending, filter_subject=subject, ), ), diff --git a/compose/local/django/start b/compose/local/django/start index 4eaa76436..5a2607cbd 100755 --- a/compose/local/django/start +++ b/compose/local/django/start @@ -6,10 +6,26 @@ set -o nounset python manage.py migrate -# Launch VS Code debug server if DEBUGGER environment variable is set to 1 -# Note that the --reload flag is not compatible with debugpy, so manually restart the server when code changes +# Set USE_UVICORN=1 to use the original raw uvicorn dev server instead of gunicorn +if [ "${USE_UVICORN:-0}" = "1" ]; then + if [ "${DEBUGGER:-0}" = "1" ]; then + exec python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:5678 -m uvicorn config.asgi:application --host 0.0.0.0 + else + exec uvicorn config.asgi:application --host 0.0.0.0 --reload --reload-include '*.html' + fi +fi + +# Gunicorn with UvicornWorker (production-parity mode, now the default) +# WEB_CONCURRENCY controls worker count (default: 1 for dev with auto-reload) +WORKERS=${WEB_CONCURRENCY:-1} + if [ "${DEBUGGER:-0}" = "1" ]; then - exec python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:5678 -m uvicorn config.asgi:application --host 0.0.0.0 + echo "Starting Gunicorn with debugpy (1 worker)..." + exec python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:5678 -m gunicorn config.asgi --bind 0.0.0.0:8000 --workers 1 -k uvicorn.workers.UvicornWorker +elif [ "$WORKERS" -eq 1 ]; then + echo "Starting Gunicorn with 1 worker (auto-reload enabled)..." + exec gunicorn config.asgi --bind 0.0.0.0:8000 --workers 1 -k uvicorn.workers.UvicornWorker --reload else - exec uvicorn config.asgi:application --host 0.0.0.0 --reload --reload-include '*.html' + echo "Starting Gunicorn with $WORKERS workers..." + exec gunicorn config.asgi --bind 0.0.0.0:8000 --workers "$WORKERS" -k uvicorn.workers.UvicornWorker fi diff --git a/compose/production/django/start b/compose/production/django/start index 5dcb00b5a..5a772895a 100644 --- a/compose/production/django/start +++ b/compose/production/django/start @@ -6,4 +6,12 @@ set -o nounset python /app/manage.py collectstatic --noinput +# Gunicorn natively reads WEB_CONCURRENCY as its --workers default. +# If not set, default to CPU core count. +if [ -z "${WEB_CONCURRENCY:-}" ]; then + export WEB_CONCURRENCY=$(nproc) +fi + +echo "Starting Gunicorn with $WEB_CONCURRENCY worker(s)..." + exec newrelic-admin run-program /usr/local/bin/gunicorn config.asgi --bind 0.0.0.0:5000 --chdir=/app -k uvicorn.workers.UvicornWorker From b717e802a33ce816745304920c359ae577c11672 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 20 Feb 2026 20:47:14 -0800 Subject: [PATCH 04/30] fix: include pipeline_slug in MinimalJobSerializer (#1148) * fix: include pipeline_slug in MinimalJobSerializer (ids_only response) The ADC worker fetches jobs with ids_only=1 and expects pipeline_slug in the response to know which pipeline to run. Without it, Pydantic validation fails and the worker skips the job. Co-Authored-By: Claude * Update ami/jobs/serializers.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Claude Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ami/jobs/serializers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index 7a3471003..d903b0812 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -158,7 +158,8 @@ class Meta(JobListSerializer.Meta): class MinimalJobSerializer(DefaultSerializer): """Minimal serializer returning only essential job fields.""" + pipeline_slug = serializers.CharField(source="pipeline.slug", read_only=True, allow_null=True) + class Meta: model = Job - # Add other fields when needed, e.g: "name", "status", "created_at" - fields = ["id"] + fields = ["id", "pipeline_slug"] From 8df89be61c0b85582450d44654e1afe467e9536d Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 24 Feb 2026 15:37:03 -0800 Subject: [PATCH 05/30] Avoid redis based locking by using atomic updates --- ami/jobs/tasks.py | 28 ++-- ami/jobs/test_tasks.py | 57 ++++--- ami/ml/orchestration/async_job_state.py | 183 ++++++++++----------- ami/ml/orchestration/tests/test_cleanup.py | 15 +- ami/ml/tests.py | 88 +++++----- 5 files changed, 175 insertions(+), 196 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 0abf85dae..6d385ba5d 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -84,15 +84,10 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub state_manager = AsyncJobStateManager(job_id) - progress_info = state_manager.update_state( - processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids - ) + progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids) if not progress_info: - logger.warning( - f"Another task is already processing results for job {job_id}. " - f"Retrying task {self.request.id} in 5 seconds..." - ) - raise self.retry(countdown=5, max_retries=10) + logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + return try: complete_state = JobState.SUCCESS @@ -150,15 +145,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub progress_info = state_manager.update_state( processed_image_ids, stage="results", - request_id=self.request.id, ) if not progress_info: - logger.warning( - f"Another task is already processing results for job {job_id}. " - f"Retrying task {self.request.id} in 5 seconds..." - ) - raise self.retry(countdown=5, max_retries=10) + logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + return # update complete state based on latest progress info after saving results complete_state = JobState.SUCCESS @@ -256,6 +247,15 @@ def _update_job_progress( state_params["classifications"] = current_classifications + new_classifications state_params["captures"] = current_captures + new_captures + # Don't overwrite a stage with a stale progress value. + # This guards against the race where a slower worker calls _update_job_progress + # after a faster worker has already marked further progress + try: + existing_stage = job.progress.get_stage(stage) + progress_percentage = max(existing_stage.progress, progress_percentage) + except (ValueError, AttributeError): + pass # Stage doesn't exist yet; proceed normally + job.progress.update_stage( stage, status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index b37940cdd..1a86f2e95 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -17,7 +17,7 @@ from ami.jobs.tasks import process_nats_pipeline_result from ami.main.models import Detection, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline -from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse from ami.users.models import User @@ -237,38 +237,43 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class): self.assertEqual(mock_manager.acknowledge_task.call_count, 3) @patch("ami.jobs.tasks.TaskQueueManager") - def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manager_class): + def test_process_nats_pipeline_result_concurrent_updates(self, mock_manager_class): """ - Test that error results respect locking mechanism. + Test that concurrent workers update state independently without contention. - Verifies race condition handling when multiple workers - process error results simultaneously. + Without a lock, two workers processing different images can both call + update_state and receive valid progress — no retry needed, no blocking. """ - # Simulate lock held by another task - lock_key = _lock_key(self.job.pk) - cache.set(lock_key, "other-task-id", timeout=60) + mock_manager = self._setup_mock_nats(mock_manager_class) - # Create error result - error_data = self._create_error_result(image_id=str(self.images[0].pk)) - reply_subject = "tasks.reply.test789" - - # Task should raise retry exception when lock not acquired - # The task internally calls self.retry() which raises a Retry exception - from celery.exceptions import Retry - - with self.assertRaises(Retry): - process_nats_pipeline_result.apply( - kwargs={ - "job_id": self.job.pk, - "result_data": error_data, - "reply_subject": reply_subject, - } - ) + # Worker 1 processes images[0] + result_1 = process_nats_pipeline_result.apply( + kwargs={ + "job_id": self.job.pk, + "result_data": self._create_error_result(image_id=str(self.images[0].pk)), + "reply_subject": "reply.concurrent.1", + } + ) + + # Worker 2 processes images[1] — no retry, no lock to wait for + result_2 = process_nats_pipeline_result.apply( + kwargs={ + "job_id": self.job.pk, + "result_data": self._create_error_result(image_id=str(self.images[1].pk)), + "reply_subject": "reply.concurrent.2", + } + ) + + self.assertTrue(result_1.successful()) + self.assertTrue(result_2.successful()) - # Assert: Progress was NOT updated (lock not acquired) + # Both images should be marked as processed manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") - self.assertEqual(progress.processed, 0) + self.assertIsNotNone(progress) + self.assertEqual(progress.processed, 2) + self.assertEqual(progress.total, 3) + self.assertEqual(mock_manager.acknowledge_task.call_count, 2) @patch("ami.jobs.tasks.TaskQueueManager") def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_class): diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py index 5a300c12a..e626c5948 100644 --- a/ami/ml/orchestration/async_job_state.py +++ b/ami/ml/orchestration/async_job_state.py @@ -2,8 +2,15 @@ Internal progress tracking for async (NATS) job processing, backed by Redis. Multiple Celery workers process image batches concurrently and report progress -here using Redis for atomic updates with locking. This module is purely internal -— nothing outside the worker pipeline reads from it directly. +here using Redis native set operations. No locking is required because: + + - SREM (remove processed images from pending set) is atomic per call + - SADD (add to failed set) is atomic per call + - SCARD (read set size) is O(1) without deserializing members + +Workers update state independently via a single Redis pipeline round-trip. +This module is purely internal — nothing outside the worker pipeline reads +from it directly. How this relates to the Job model (ami/jobs/models.py): @@ -27,7 +34,7 @@ import logging from dataclasses import dataclass -from django.core.cache import cache +from django_redis import get_redis_connection logger = logging.getLogger(__name__) @@ -50,17 +57,13 @@ class JobStateProgress: failed: int = 0 # source images that returned an error from the processing service -def _lock_key(job_id: int) -> str: - return f"job:{job_id}:process_results_lock" - - class AsyncJobStateManager: """ Manages real-time job progress in Redis for concurrent NATS workers. - Each job has per-stage pending image lists and a shared failed image set. - Workers acquire a Redis lock before mutating state, ensuring atomic updates - even when multiple Celery tasks process batches in parallel. + Each job has per-stage pending image sets and a shared failed image set, + all stored as native Redis sets. Workers update state via atomic SREM/SADD + commands — no locking needed. The results are ephemeral — _update_job_progress() in ami/jobs/tasks.py copies each snapshot into the persistent Job.progress JSONB field. @@ -70,17 +73,14 @@ class AsyncJobStateManager: STAGES = ["process", "results"] def __init__(self, job_id: int): - """ - Initialize the task state manager for a specific job. - - Args: - job_id: The job primary key - """ self.job_id = job_id self._pending_key = f"job:{job_id}:pending_images" self._total_key = f"job:{job_id}:pending_images_total" self._failed_key = f"job:{job_id}:failed_images" + def _get_redis(self): + return get_redis_connection("default") + def initialize_job(self, image_ids: list[str]) -> None: """ Initialize job tracking with a list of image IDs to process. @@ -88,13 +88,17 @@ def initialize_job(self, image_ids: list[str]) -> None: Args: image_ids: List of image IDs that need to be processed """ - for stage in self.STAGES: - cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) - - # Initialize failed images set for process stage only - cache.set(self._failed_key, set(), timeout=self.TIMEOUT) - - cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + redis = self._get_redis() + with redis.pipeline() as pipe: + for stage in self.STAGES: + pending_key = self._get_pending_key(stage) + pipe.delete(pending_key) + if image_ids: + pipe.sadd(pending_key, *image_ids) + pipe.expire(pending_key, self.TIMEOUT) + pipe.delete(self._failed_key) + pipe.set(self._total_key, len(image_ids), ex=self.TIMEOUT) + pipe.execute() def _get_pending_key(self, stage: str) -> str: return f"{self._pending_key}:{stage}" @@ -103,100 +107,81 @@ def update_state( self, processed_image_ids: set[str], stage: str, - request_id: str, failed_image_ids: set[str] | None = None, - ) -> None | JobStateProgress: + ) -> "JobStateProgress | None": """ - Update the task state with newly processed images. + Atomically update job state with newly processed images. + + Uses a Redis pipeline (single round-trip). SREM and SADD are each + individually atomic; the pipeline batches them with SCARD/GET to avoid + multiple round-trips. Workers can call this concurrently — no lock needed. Args: processed_image_ids: Set of image IDs that have just been processed stage: The processing stage ("process" or "results") - request_id: Unique identifier for this processing request - detections_count: Number of detections to add to cumulative count - classifications_count: Number of classifications to add to cumulative count - captures_count: Number of captures to add to cumulative count failed_image_ids: Set of image IDs that failed processing (optional) + + Returns: + JobStateProgress snapshot, or None if Redis state is missing + (job expired or not yet initialized). """ - # Create a unique lock key for this job - lock_key = _lock_key(self.job_id) - lock_timeout = 360 # 6 minutes (matches task time_limit) - lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) - if not lock_acquired: + redis = self._get_redis() + pending_key = self._get_pending_key(stage) + + with redis.pipeline() as pipe: + if processed_image_ids: + pipe.srem(pending_key, *processed_image_ids) + if failed_image_ids: + pipe.sadd(self._failed_key, *failed_image_ids) + pipe.scard(pending_key) + pipe.scard(self._failed_key) + pipe.get(self._total_key) + results = pipe.execute() + + # Last 3 results are always scard(pending), scard(failed), get(total) + # regardless of whether SREM/SADD appear at the front. + remaining, failed_count, total_raw = results[-3], results[-2], results[-1] + + if total_raw is None: return None - try: - # Update progress tracking in Redis - progress_info = self._commit_update(processed_image_ids, stage, failed_image_ids) - return progress_info - finally: - # Always release the lock when done - current_lock_value = cache.get(lock_key) - # Only delete if we still own the lock (prevents race condition) - if current_lock_value == request_id: - cache.delete(lock_key) - logger.debug(f"Released lock for job {self.job_id}, task {request_id}") - - def get_progress(self, stage: str) -> JobStateProgress | None: - """Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state.""" - pending_images = cache.get(self._get_pending_key(stage)) - total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: - return None - remaining = len(pending_images) - processed = total_images - remaining - percentage = float(processed) / total_images if total_images > 0 else 1.0 - failed_set = cache.get(self._failed_key) or set() + total = int(total_raw) + processed = total - remaining + percentage = float(processed) / total if total > 0 else 1.0 + + logger.info( + f"Pending images from Redis for job {self.job_id} {stage}: " f"{remaining}/{total}: {percentage*100}%" + ) + return JobStateProgress( remaining=remaining, - total=total_images, + total=total, processed=processed, percentage=percentage, - failed=len(failed_set), + failed=failed_count, ) - def _commit_update( - self, - processed_image_ids: set[str], - stage: str, - failed_image_ids: set[str] | None = None, - ) -> JobStateProgress | None: - """ - Update pending images and return progress. Must be called under lock. + def get_progress(self, stage: str) -> "JobStateProgress | None": + """Read-only progress snapshot for the given stage.""" + redis = self._get_redis() + pending_key = self._get_pending_key(stage) - Removes processed_image_ids from the pending set and persists the update. - """ - pending_images = cache.get(self._get_pending_key(stage)) - total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: - return None - remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] - assert len(pending_images) >= len(remaining_images) - cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) - - remaining = len(remaining_images) - processed = total_images - remaining - percentage = float(processed) / total_images if total_images > 0 else 1.0 - - # Update failed images set if provided - if failed_image_ids: - existing_failed = cache.get(self._failed_key) or set() - updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates - cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT) - failed_set = updated_failed - else: - failed_set = cache.get(self._failed_key) or set() + with redis.pipeline() as pipe: + pipe.scard(pending_key) + pipe.scard(self._failed_key) + pipe.get(self._total_key) + remaining, failed_count, total_raw = pipe.execute() - failed_count = len(failed_set) + if total_raw is None: + return None - logger.info( - f"Pending images from Redis for job {self.job_id} {stage}: " - f"{remaining}/{total_images}: {percentage*100}%" - ) + total = int(total_raw) + processed = total - remaining + percentage = float(processed) / total if total > 0 else 1.0 return JobStateProgress( remaining=remaining, - total=total_images, + total=total, processed=processed, percentage=percentage, failed=failed_count, @@ -206,7 +191,7 @@ def cleanup(self) -> None: """ Delete all Redis keys associated with this job. """ - for stage in self.STAGES: - cache.delete(self._get_pending_key(stage)) - cache.delete(self._failed_key) - cache.delete(self._total_key) + redis = self._get_redis() + keys = [self._get_pending_key(stage) for stage in self.STAGES] + keys += [self._failed_key, self._total_key] + redis.delete(*keys) diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index ccdfa2c49..cb626348a 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -1,7 +1,6 @@ """Integration tests for async job resource cleanup (NATS and Redis).""" from asgiref.sync import async_to_sync -from django.core.cache import cache from django.test import TestCase from nats.js.errors import NotFoundError @@ -58,13 +57,10 @@ def _verify_resources_created(self, job_id: int): Args: job_id: The job ID to check """ - # Verify Redis keys exist + # Verify Redis state exists (get_progress returns non-None when total_key is set) state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: - pending_key = state_manager._get_pending_key(stage) - self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist") - total_key = state_manager._total_key - self.assertIsNotNone(cache.get(total_key), f"Redis key {total_key} should exist") + self.assertIsNotNone(state_manager.get_progress(stage), f"Redis state for stage '{stage}' should exist") # Verify NATS stream and consumer exist async def check_nats_resources(): @@ -124,13 +120,10 @@ def _verify_resources_cleaned(self, job_id: int): Args: job_id: The job ID to check """ - # Verify Redis keys are deleted + # Verify Redis state is deleted (get_progress returns None when total_key is gone) state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: - pending_key = state_manager._get_pending_key(stage) - self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted") - total_key = state_manager._total_key - self.assertIsNone(cache.get(total_key), f"Redis key {total_key} should be deleted") + self.assertIsNone(state_manager.get_progress(stage), f"Redis state for stage '{stage}' should be deleted") # Verify NATS stream and consumer are deleted async def check_nats_resources(): diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 6d029492b..6a34040a8 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -919,7 +919,7 @@ def setUp(self): def _init_and_verify(self, image_ids): """Helper to initialize job and verify initial state.""" self.manager.initialize_job(image_ids) - progress = self.manager._commit_update(set(), "process") + progress = self.manager.get_progress("process") assert progress is not None self.assertEqual(progress.total, len(image_ids)) self.assertEqual(progress.remaining, len(image_ids)) @@ -934,7 +934,7 @@ def test_initialize_job(self): # Verify both stages are initialized for stage in self.manager.STAGES: - progress = self.manager._commit_update(set(), stage) + progress = self.manager.get_progress(stage) assert progress is not None self.assertEqual(progress.total, len(self.image_ids)) self.assertEqual(progress.failed, 0) @@ -944,70 +944,70 @@ def test_progress_tracking(self): self._init_and_verify(self.image_ids) # Process 2 images - progress = self.manager._commit_update({"img1", "img2"}, "process") + progress = self.manager.update_state({"img1", "img2"}, "process") assert progress is not None self.assertEqual(progress.remaining, 3) self.assertEqual(progress.processed, 2) self.assertEqual(progress.percentage, 0.4) # Process 2 more images - progress = self.manager._commit_update({"img3", "img4"}, "process") + progress = self.manager.update_state({"img3", "img4"}, "process") assert progress is not None self.assertEqual(progress.remaining, 1) self.assertEqual(progress.processed, 4) self.assertEqual(progress.percentage, 0.8) # Process last image - progress = self.manager._commit_update({"img5"}, "process") + progress = self.manager.update_state({"img5"}, "process") assert progress is not None self.assertEqual(progress.remaining, 0) self.assertEqual(progress.processed, 5) self.assertEqual(progress.percentage, 1.0) - def test_update_state_with_locking(self): - """Test update_state acquires lock, updates progress, and releases lock.""" - from django.core.cache import cache - + def test_update_state_concurrent(self): + """Test that concurrent workers update state independently without blocking.""" self._init_and_verify(self.image_ids) - # First update should succeed - progress = self.manager.update_state({"img1", "img2"}, "process", "task1") - assert progress is not None - self.assertEqual(progress.processed, 2) + # Worker 1 processes img1, img2 — succeeds immediately + progress_1 = self.manager.update_state({"img1", "img2"}, "process") + assert progress_1 is not None + self.assertEqual(progress_1.processed, 2) - # Simulate concurrent update by holding the lock - lock_key = f"job:{self.job_id}:process_results_lock" - cache.set(lock_key, "other_task", timeout=60) + # Worker 2 processes img3 — no lock to wait for, also succeeds immediately + progress_2 = self.manager.update_state({"img3"}, "process") + assert progress_2 is not None + self.assertEqual(progress_2.processed, 3) - # Update should fail (lock held by another task) - progress = self.manager.update_state({"img3"}, "process", "task1") - self.assertIsNone(progress) + # Final state reflects both updates + final = self.manager.get_progress("process") + assert final is not None + self.assertEqual(final.processed, 3) + self.assertEqual(final.remaining, 2) - # Release the lock and retry - cache.delete(lock_key) - progress = self.manager.update_state({"img3"}, "process", "task1") - assert progress is not None - self.assertEqual(progress.processed, 3) + # SREM is idempotent: retrying already-processed images doesn't change counts + progress_retry = self.manager.update_state({"img1", "img2"}, "process") + assert progress_retry is not None + self.assertEqual(progress_retry.processed, 3) def test_stages_independent(self): """Test that different stages track progress independently.""" self._init_and_verify(self.image_ids) # Update process stage - self.manager._commit_update({"img1", "img2"}, "process") - progress_process = self.manager._commit_update(set(), "process") + self.manager.update_state({"img1", "img2"}, "process") + progress_process = self.manager.get_progress("process") assert progress_process is not None self.assertEqual(progress_process.remaining, 3) # Results stage should still have all images pending - progress_results = self.manager._commit_update(set(), "results") + progress_results = self.manager.get_progress("results") assert progress_results is not None self.assertEqual(progress_results.remaining, 5) def test_empty_job(self): """Test handling of job with no images.""" self.manager.initialize_job([]) - progress = self.manager._commit_update(set(), "process") + progress = self.manager.get_progress("process") assert progress is not None self.assertEqual(progress.total, 0) self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete @@ -1017,14 +1017,14 @@ def test_cleanup(self): self._init_and_verify(self.image_ids) # Verify keys exist - progress = self.manager._commit_update(set(), "process") + progress = self.manager.get_progress("process") self.assertIsNotNone(progress) # Cleanup self.manager.cleanup() # Verify keys are gone - progress = self.manager._commit_update(set(), "process") + progress = self.manager.get_progress("process") self.assertIsNone(progress) def test_failed_image_tracking(self): @@ -1032,17 +1032,17 @@ def test_failed_image_tracking(self): self._init_and_verify(self.image_ids) # Mark 2 images as failed in process stage - progress = self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + progress = self.manager.update_state({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) assert progress is not None self.assertEqual(progress.failed, 2) - # Retry same 2 images (fail again) - should not double-count - progress = self.manager._commit_update(set(), "process", failed_image_ids={"img1", "img2"}) + # Retry same 2 images (fail again) - SADD is idempotent, no double-counting + progress = self.manager.update_state(set(), "process", failed_image_ids={"img1", "img2"}) assert progress is not None self.assertEqual(progress.failed, 2) # Fail a different image - progress = self.manager._commit_update(set(), "process", failed_image_ids={"img3"}) + progress = self.manager.update_state(set(), "process", failed_image_ids={"img3"}) assert progress is not None self.assertEqual(progress.failed, 3) @@ -1051,7 +1051,7 @@ def test_failed_and_processed_mixed(self): self._init_and_verify(self.image_ids) # Process 2 successfully, 2 fail, 1 remains pending - progress = self.manager._commit_update( + progress = self.manager.update_state( {"img1", "img2", "img3", "img4"}, "process", failed_image_ids={"img3", "img4"} ) assert progress is not None @@ -1062,20 +1062,16 @@ def test_failed_and_processed_mixed(self): def test_cleanup_removes_failed_set(self): """Test that cleanup removes failed image set.""" - from django.core.cache import cache - self._init_and_verify(self.image_ids) - # Add failed images - self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) - - # Verify failed set exists - failed_set = cache.get(self.manager._failed_key) - self.assertEqual(len(failed_set), 2) + # Add failed images and verify they're tracked + progress = self.manager.update_state({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) # Cleanup self.manager.cleanup() - # Verify failed set is gone - failed_set = cache.get(self.manager._failed_key) - self.assertIsNone(failed_set) + # Verify all state is gone (get_progress returns None when total_key is deleted) + progress = self.manager.get_progress("process") + self.assertIsNone(progress) From 30c8db349d80046ad3a32775020a33c2f40e8947 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 24 Feb 2026 16:20:13 -0800 Subject: [PATCH 06/30] Test concurrency --- ami/ml/tests.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 6a34040a8..353dcfdf1 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -1,3 +1,4 @@ +import concurrent.futures import datetime import pathlib import unittest @@ -965,29 +966,35 @@ def test_progress_tracking(self): self.assertEqual(progress.percentage, 1.0) def test_update_state_concurrent(self): - """Test that concurrent workers update state independently without blocking.""" + """Test that concurrent workers update state correctly without data races.""" self._init_and_verify(self.image_ids) - # Worker 1 processes img1, img2 — succeeds immediately - progress_1 = self.manager.update_state({"img1", "img2"}, "process") - assert progress_1 is not None - self.assertEqual(progress_1.processed, 2) + # Three workers process disjoint image sets truly concurrently + errors: list[Exception] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [ + executor.submit(self.manager.update_state, {"img1", "img2"}, "process"), + executor.submit(self.manager.update_state, {"img3"}, "process"), + executor.submit(self.manager.update_state, {"img4", "img5"}, "process"), + ] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + errors.append(e) - # Worker 2 processes img3 — no lock to wait for, also succeeds immediately - progress_2 = self.manager.update_state({"img3"}, "process") - assert progress_2 is not None - self.assertEqual(progress_2.processed, 3) + self.assertEqual(errors, [], f"Concurrent workers raised exceptions: {errors}") - # Final state reflects both updates + # Final state reflects all concurrent updates final = self.manager.get_progress("process") assert final is not None - self.assertEqual(final.processed, 3) - self.assertEqual(final.remaining, 2) + self.assertEqual(final.processed, 5) + self.assertEqual(final.remaining, 0) # SREM is idempotent: retrying already-processed images doesn't change counts progress_retry = self.manager.update_state({"img1", "img2"}, "process") assert progress_retry is not None - self.assertEqual(progress_retry.processed, 3) + self.assertEqual(progress_retry.processed, 5) def test_stages_independent(self): """Test that different stages track progress independently.""" From deea095d6e6e48036b6abc6b68bb79c445bd2d25 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 09:22:20 -0800 Subject: [PATCH 07/30] Increase max ack pending --- ami/ml/orchestration/nats_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 489a0dc66..e55440fd0 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -64,7 +64,7 @@ class TaskQueueManager: def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = None): self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") self.max_ack_pending = ( - max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 100) + max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 1000) ) self.nc: nats.NATS | None = None self.js: JetStreamContext | None = None From 20c0fbd29b2d970dc70dde292c3b778fa63c8e8a Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 09:23:20 -0800 Subject: [PATCH 08/30] update comment --- ami/ml/orchestration/nats_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index e55440fd0..a23d28ac8 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -52,7 +52,7 @@ class TaskQueueManager: Args: nats_url: NATS server URL. Falls back to settings.NATS_URL, then "nats://nats:4222". max_ack_pending: Max unacknowledged messages per consumer. Falls back to - settings.NATS_MAX_ACK_PENDING, then 100. + settings.NATS_MAX_ACK_PENDING, then 1000. Use as an async context manager: async with TaskQueueManager() as manager: From e84421eb88f32a8027c8f2e2da685f0dcf7988f8 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 13:43:33 -0800 Subject: [PATCH 09/30] CR feedback --- ami/jobs/tasks.py | 7 +++++++ ami/jobs/test_tasks.py | 40 ++++++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 6d385ba5d..92a5680a4 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -87,6 +87,9 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids) if not progress_info: logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + # Acknowledge the task to prevent retries, since we don't know the state + _ack_task_via_nats(reply_subject, logger) + # TODO: cancel the job to fail fast once PR #1144 is merged return try: @@ -149,6 +152,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub if not progress_info: logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + # TODO: cancel the job to fail fast once PR #1144 is merged return # update complete state based on latest progress info after saving results @@ -253,6 +257,9 @@ def _update_job_progress( try: existing_stage = job.progress.get_stage(stage) progress_percentage = max(existing_stage.progress, progress_percentage) + # JobState is ordered with FAILURE < SUCCESS, so max() will keep it at FAILURE + # if any worker reported failure + complete_state = max(existing_stage.status, complete_state) except (ValueError, AttributeError): pass # Stage doesn't exist yet; proceed normally diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index 1a86f2e95..7d5cb25aa 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -6,6 +6,7 @@ """ import logging +from concurrent.futures import ThreadPoolExecutor from unittest.mock import AsyncMock, MagicMock, patch from django.core.cache import cache @@ -246,26 +247,29 @@ def test_process_nats_pipeline_result_concurrent_updates(self, mock_manager_clas """ mock_manager = self._setup_mock_nats(mock_manager_class) - # Worker 1 processes images[0] - result_1 = process_nats_pipeline_result.apply( - kwargs={ - "job_id": self.job.pk, - "result_data": self._create_error_result(image_id=str(self.images[0].pk)), - "reply_subject": "reply.concurrent.1", - } - ) + with ThreadPoolExecutor(max_workers=2) as executor: + # Worker 1 processes images[0] + result_1 = executor.submit( + process_nats_pipeline_result.apply, + kwargs={ + "job_id": self.job.pk, + "result_data": self._create_error_result(image_id=str(self.images[0].pk)), + "reply_subject": "reply.concurrent.1", + }, + ) - # Worker 2 processes images[1] — no retry, no lock to wait for - result_2 = process_nats_pipeline_result.apply( - kwargs={ - "job_id": self.job.pk, - "result_data": self._create_error_result(image_id=str(self.images[1].pk)), - "reply_subject": "reply.concurrent.2", - } - ) + # Worker 2 processes images[1] — no retry, no lock to wait for + result_2 = executor.submit( + process_nats_pipeline_result.apply, + kwargs={ + "job_id": self.job.pk, + "result_data": self._create_error_result(image_id=str(self.images[1].pk)), + "reply_subject": "reply.concurrent.2", + }, + ) - self.assertTrue(result_1.successful()) - self.assertTrue(result_2.successful()) + self.assertTrue(result_1.result().successful()) + self.assertTrue(result_2.result().successful()) # Both images should be marked as processed manager = AsyncJobStateManager(self.job.pk) From cbb2d7fef2072e52d0853e1fbce9b6f80793c7be Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 13:53:07 -0800 Subject: [PATCH 10/30] Cancel jobs if Redis state is missing --- ami/jobs/tasks.py | 20 +++++++++++++++----- ami/jobs/views.py | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 92a5680a4..4c4547120 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -86,10 +86,9 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids) if not progress_info: - logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") - # Acknowledge the task to prevent retries, since we don't know the state + # Acknowledge the task to prevent retries _ack_task_via_nats(reply_subject, logger) - # TODO: cancel the job to fail fast once PR #1144 is merged + _cancel_job_on_missing_state(job_id, logger) return try: @@ -151,8 +150,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) if not progress_info: - logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") - # TODO: cancel the job to fail fast once PR #1144 is merged + _cancel_job_on_missing_state(job_id, logger) return # update complete state based on latest progress info after saving results @@ -176,6 +174,18 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) +def _cancel_job_on_missing_state(job_id: int, logger: logging.Logger) -> None: + from ami.jobs.models import Job, JobState + + logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + + # cancel job (fail fast) since we don't know the state + job = Job.objects.get(pk=job_id) + if job.status != JobState.CANCELING and job.status not in JobState.final_states(): + job.logger.error(f"Job {job_id} is not canceling or finished, but Redis state is missing. ") + job.cancel() + + def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: try: diff --git a/ami/jobs/views.py b/ami/jobs/views.py index ddc1e57a7..cdd1886e7 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -238,7 +238,7 @@ def tasks(self, request, pk=None): raise ValidationError("Only async_api jobs have fetchable tasks") # Don't fetch tasks from completed/failed/revoked jobs - if job.status in JobState.final_states(): + if job.status in JobState.final_states() or job.status == JobState.CANCELING: return Response({"tasks": []}) # Validate that the job has a pipeline From 38611909579a5d4da857d18011b6bc936fd24277 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 13:55:44 -0800 Subject: [PATCH 11/30] Add chaos monkey --- ami/jobs/management/commands/chaos_monkey.py | 146 +++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 ami/jobs/management/commands/chaos_monkey.py diff --git a/ami/jobs/management/commands/chaos_monkey.py b/ami/jobs/management/commands/chaos_monkey.py new file mode 100644 index 000000000..f1c2793b5 --- /dev/null +++ b/ami/jobs/management/commands/chaos_monkey.py @@ -0,0 +1,146 @@ +""" +Fault injection utility for manual chaos testing of ML async jobs. + +Use alongside `test_ml_job_e2e` to verify job behaviour when Redis or NATS +becomes unavailable or loses state mid-processing. + +Usage examples: + + # Flush all Redis state immediately (simulates FLUSHDB mid-job) + python manage.py chaos_monkey flush redis + + # Flush all NATS JetStream streams (simulates broker state loss) + python manage.py chaos_monkey flush nats + + # Pause Redis for 15 seconds then restore (simulates transient outage) + python manage.py chaos_monkey pause redis + + # Pause NATS for 30 seconds then restore + python manage.py chaos_monkey pause nats --duration 30 +""" + +import subprocess +import time + +from asgiref.sync import async_to_sync +from django.core.management.base import BaseCommand, CommandError +from django_redis import get_redis_connection + +REDIS_CONTAINER = "ami_local_redis" +NATS_CONTAINER = "ami_local_nats" +NATS_URL = "nats://ami_local_nats:4222" + + +class Command(BaseCommand): + help = "Inject faults into Redis or NATS for chaos/resilience testing" + + def add_arguments(self, parser): + parser.add_argument( + "action", + choices=["flush", "pause"], + help="flush: wipe all state. pause: stop the service temporarily then restore it.", + ) + parser.add_argument( + "service", + choices=["redis", "nats"], + help="Target service to fault.", + ) + parser.add_argument( + "--duration", + type=int, + default=15, + metavar="SECONDS", + help="How long to keep the service paused before restoring (pause only, default: 15).", + ) + + def handle(self, *args, **options): + action = options["action"] + service = options["service"] + duration = options["duration"] + + if action == "flush" and service == "redis": + self._flush_redis() + elif action == "flush" and service == "nats": + self._flush_nats() + elif action == "pause" and service == "redis": + self._pause_container(REDIS_CONTAINER, duration) + elif action == "pause" and service == "nats": + self._pause_container(NATS_CONTAINER, duration) + + # ------------------------------------------------------------------ + # Redis + # ------------------------------------------------------------------ + + def _flush_redis(self): + self.stdout.write("Flushing Redis database (FLUSHDB)...") + try: + redis = get_redis_connection("default") + redis.flushdb() + self.stdout.write(self.style.SUCCESS("Redis flushed.")) + except Exception as e: + raise CommandError(f"Failed to flush Redis: {e}") from e + + # ------------------------------------------------------------------ + # NATS + # ------------------------------------------------------------------ + + def _flush_nats(self): + """Delete all JetStream streams via the NATS Python client.""" + self.stdout.write("Flushing all NATS JetStream streams...") + + async def _delete_all_streams(): + import nats + + nc = await nats.connect(NATS_URL, connect_timeout=5, allow_reconnect=False) + js = nc.jetstream() + try: + streams = await js.streams_info() + if not streams: + return [] + deleted = [] + for stream in streams: + name = stream.config.name + await js.delete_stream(name) + deleted.append(name) + return deleted + finally: + await nc.close() + + try: + deleted = async_to_sync(_delete_all_streams)() + except Exception as e: + raise CommandError(f"Failed to flush NATS: {e}") from e + + if deleted: + for name in deleted: + self.stdout.write(f" Deleted stream: {name}") + self.stdout.write(self.style.SUCCESS(f"Deleted {len(deleted)} stream(s).")) + else: + self.stdout.write("No streams found — NATS already empty.") + + # ------------------------------------------------------------------ + # Container pause/unpause (works for both redis and nats) + # ------------------------------------------------------------------ + + def _pause_container(self, container: str, duration: int): + self.stdout.write(f"Pausing container '{container}' for {duration}s...") + self._docker("pause", container) + self.stdout.write(self.style.WARNING(f"Container paused. Waiting {duration}s...")) + + for remaining in range(duration, 0, -1): + self.stdout.write(f"\r {remaining}s remaining...", ending="") + self.stdout.flush() + time.sleep(1) + + self.stdout.write("") # newline after countdown + self._docker("unpause", container) + self.stdout.write(self.style.SUCCESS(f"Container '{container}' restored.")) + + def _docker(self, subcommand: str, container: str): + result = subprocess.run( + ["docker", subcommand, container], + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise CommandError(f"`docker {subcommand} {container}` failed:\n{result.stderr.strip()}") From d591bd63387606c39709ab0f8aee3e7ebc0b017c Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 15:38:11 -0800 Subject: [PATCH 12/30] CR feedback --- ami/jobs/tasks.py | 3 +- ami/jobs/test_tasks.py | 4 +- ami/ml/orchestration/async_job_state.py | 87 +++++++++++++++---------- ami/ml/tests.py | 9 +-- 4 files changed, 59 insertions(+), 44 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 92a5680a4..074b05c87 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -257,7 +257,7 @@ def _update_job_progress( try: existing_stage = job.progress.get_stage(stage) progress_percentage = max(existing_stage.progress, progress_percentage) - # JobState is ordered with FAILURE < SUCCESS, so max() will keep it at FAILURE + # JobState is ordered with FAILURE > SUCCESS, so max() will keep it at FAILURE # if any worker reported failure complete_state = max(existing_stage.status, complete_state) except (ValueError, AttributeError): @@ -265,6 +265,7 @@ def _update_job_progress( job.progress.update_stage( stage, + # always use STARTED for in-progress updates status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, **state_params, diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index 7d5cb25aa..25e609244 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from django.core.cache import cache -from django.test import TestCase +from django.test import TransactionTestCase from rest_framework.test import APITestCase from ami.base.serializers import reverse_with_params @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -class TestProcessNatsPipelineResultError(TestCase): +class TestProcessNatsPipelineResultError(TransactionTestCase): """E2E tests for process_nats_pipeline_result with error handling.""" def setUp(self): diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py index e626c5948..1f3f2e371 100644 --- a/ami/ml/orchestration/async_job_state.py +++ b/ami/ml/orchestration/async_job_state.py @@ -35,6 +35,7 @@ from dataclasses import dataclass from django_redis import get_redis_connection +from redis.exceptions import RedisError logger = logging.getLogger(__name__) @@ -88,17 +89,21 @@ def initialize_job(self, image_ids: list[str]) -> None: Args: image_ids: List of image IDs that need to be processed """ - redis = self._get_redis() - with redis.pipeline() as pipe: - for stage in self.STAGES: - pending_key = self._get_pending_key(stage) - pipe.delete(pending_key) - if image_ids: - pipe.sadd(pending_key, *image_ids) - pipe.expire(pending_key, self.TIMEOUT) - pipe.delete(self._failed_key) - pipe.set(self._total_key, len(image_ids), ex=self.TIMEOUT) - pipe.execute() + try: + redis = self._get_redis() + with redis.pipeline() as pipe: + for stage in self.STAGES: + pending_key = self._get_pending_key(stage) + pipe.delete(pending_key) + if image_ids: + pipe.sadd(pending_key, *image_ids) + pipe.expire(pending_key, self.TIMEOUT) + pipe.delete(self._failed_key) + pipe.set(self._total_key, len(image_ids), ex=self.TIMEOUT) + pipe.execute() + except RedisError as e: + logger.error(f"Redis error initializing job {self.job_id}: {e}") + raise def _get_pending_key(self, stage: str) -> str: return f"{self._pending_key}:{stage}" @@ -125,18 +130,23 @@ def update_state( JobStateProgress snapshot, or None if Redis state is missing (job expired or not yet initialized). """ - redis = self._get_redis() - pending_key = self._get_pending_key(stage) - - with redis.pipeline() as pipe: - if processed_image_ids: - pipe.srem(pending_key, *processed_image_ids) - if failed_image_ids: - pipe.sadd(self._failed_key, *failed_image_ids) - pipe.scard(pending_key) - pipe.scard(self._failed_key) - pipe.get(self._total_key) - results = pipe.execute() + try: + redis = self._get_redis() + pending_key = self._get_pending_key(stage) + + with redis.pipeline() as pipe: + if processed_image_ids: + pipe.srem(pending_key, *processed_image_ids) + if failed_image_ids: + pipe.sadd(self._failed_key, *failed_image_ids) + pipe.expire(self._failed_key, self.TIMEOUT) + pipe.scard(pending_key) + pipe.scard(self._failed_key) + pipe.get(self._total_key) + results = pipe.execute() + except RedisError as e: + logger.error(f"Redis error updating job {self.job_id} state: {e}") + return None # Last 3 results are always scard(pending), scard(failed), get(total) # regardless of whether SREM/SADD appear at the front. @@ -163,14 +173,18 @@ def update_state( def get_progress(self, stage: str) -> "JobStateProgress | None": """Read-only progress snapshot for the given stage.""" - redis = self._get_redis() - pending_key = self._get_pending_key(stage) - - with redis.pipeline() as pipe: - pipe.scard(pending_key) - pipe.scard(self._failed_key) - pipe.get(self._total_key) - remaining, failed_count, total_raw = pipe.execute() + try: + redis = self._get_redis() + pending_key = self._get_pending_key(stage) + + with redis.pipeline() as pipe: + pipe.scard(pending_key) + pipe.scard(self._failed_key) + pipe.get(self._total_key) + remaining, failed_count, total_raw = pipe.execute() + except RedisError as e: + logger.error(f"Redis error reading job {self.job_id} progress: {e}") + return None if total_raw is None: return None @@ -191,7 +205,10 @@ def cleanup(self) -> None: """ Delete all Redis keys associated with this job. """ - redis = self._get_redis() - keys = [self._get_pending_key(stage) for stage in self.STAGES] - keys += [self._failed_key, self._total_key] - redis.delete(*keys) + try: + redis = self._get_redis() + keys = [self._get_pending_key(stage) for stage in self.STAGES] + keys += [self._failed_key, self._total_key] + redis.delete(*keys) + except RedisError as e: + logger.warning(f"Redis error cleaning up job {self.job_id}: {e}") diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 353dcfdf1..f5fb78867 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -970,18 +970,15 @@ def test_update_state_concurrent(self): self._init_and_verify(self.image_ids) # Three workers process disjoint image sets truly concurrently - errors: list[Exception] = [] + errors: list[BaseException] = [] with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: futures = [ executor.submit(self.manager.update_state, {"img1", "img2"}, "process"), executor.submit(self.manager.update_state, {"img3"}, "process"), executor.submit(self.manager.update_state, {"img4", "img5"}, "process"), ] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - errors.append(e) + _errors = [f.exception() for f in concurrent.futures.as_completed(futures)] + errors = [e for e in _errors if e is not None] self.assertEqual(errors, [], f"Concurrent workers raised exceptions: {errors}") From 4720bb6647508888e813e7cd4f849f60acb73957 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 16:02:00 -0800 Subject: [PATCH 13/30] CR 2 --- ami/jobs/tasks.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 074b05c87..2d4f9e8b9 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -124,6 +124,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub _ack_task_via_nats(reply_subject, logger) return + acked = False try: # Save to database (this is the slow operation) detections_count, classifications_count, captures_count = 0, 0, 0 @@ -143,6 +144,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub captures_count = len(pipeline_result.source_images) _ack_task_via_nats(reply_subject, job.logger) + acked = True # Update job stage with calculated progress progress_info = state_manager.update_state( @@ -171,9 +173,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) except Exception as e: - job.logger.error( - f"Failed to process pipeline result for job {job_id}: {e}. NATS will redeliver the task message." - ) + error = f"Error processing pipeline result for job {job_id}: {e}" + if not acked: + error += ". NATS will re-deliver the task message." + + job.logger.error(error) def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: @@ -253,20 +257,31 @@ def _update_job_progress( # Don't overwrite a stage with a stale progress value. # This guards against the race where a slower worker calls _update_job_progress - # after a faster worker has already marked further progress + # after a faster worker has already marked further progress. try: existing_stage = job.progress.get_stage(stage) progress_percentage = max(existing_stage.progress, progress_percentage) - # JobState is ordered with FAILURE > SUCCESS, so max() will keep it at FAILURE - # if any worker reported failure - complete_state = max(existing_stage.status, complete_state) + # Explicitly preserve FAILURE: once a stage is marked FAILURE it should + # never regress to a non-failure state, regardless of enum ordering. + if existing_stage.status == JobState.FAILURE: + complete_state = JobState.FAILURE except (ValueError, AttributeError): pass # Stage doesn't exist yet; proceed normally + # Determine the status to write: + # - Stage complete (100%): use complete_state (SUCCESS or FAILURE) + # - Stage incomplete but FAILURE already determined: keep FAILURE visible + # - Stage incomplete, no failure: mark as in-progress (STARTED) + if progress_percentage >= 1.0: + status = complete_state + elif complete_state == JobState.FAILURE: + status = JobState.FAILURE + else: + status = JobState.STARTED + job.progress.update_stage( stage, - # always use STARTED for in-progress updates - status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, + status=status, progress=progress_percentage, **state_params, ) From f0cd403475db3f0671877d79df028198e02b869a Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 25 Feb 2026 17:15:06 -0800 Subject: [PATCH 14/30] fix: OrderedEnum comparisons now override str MRO in subclasses JobState(str, OrderedEnum) was using str's lexicographic __gt__ instead of OrderedEnum's definition-order __gt__, because str comes first in the MRO. This caused max(FAILURE, SUCCESS) to return SUCCESS, silently discarding failure state in concurrent job progress updates. Fix: __init_subclass__ injects comparison methods directly onto each subclass so they take MRO priority over data-type mixins. Also preserve FAILURE status through the progress ternary when progress < 1.0, so early failure detection isn't overwritten. Co-Authored-By: Claude --- ami/jobs/tasks.py | 9 +++++---- ami/utils/schemas.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 074b05c87..4fae09a65 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -257,16 +257,17 @@ def _update_job_progress( try: existing_stage = job.progress.get_stage(stage) progress_percentage = max(existing_stage.progress, progress_percentage) - # JobState is ordered with FAILURE > SUCCESS, so max() will keep it at FAILURE - # if any worker reported failure + # FAILURE is defined after SUCCESS in JobState, so max() preserves FAILURE + # if any worker reported failure (updated OrderedEnum so max() works for this usage) complete_state = max(existing_stage.status, complete_state) except (ValueError, AttributeError): pass # Stage doesn't exist yet; proceed normally job.progress.update_stage( stage, - # always use STARTED for in-progress updates - status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, + status=complete_state + if progress_percentage >= 1.0 or complete_state == JobState.FAILURE + else JobState.STARTED, progress=progress_percentage, **state_params, ) diff --git a/ami/utils/schemas.py b/ami/utils/schemas.py index 0bf69ea70..fa8946375 100644 --- a/ami/utils/schemas.py +++ b/ami/utils/schemas.py @@ -12,13 +12,42 @@ class OrderedEnum(Enum): This also implements a case-insensitive lookup for values. + Comparison methods are injected onto subclasses via __init_subclass__ so that + definition-order comparisons take MRO priority over data-type mixins (str, int). + Source https://stackoverflow.com/a/58367726/966058 + + >>> class Priority(str, OrderedEnum): + ... LOW = "LOW" + ... MEDIUM = "MEDIUM" + ... HIGH = "HIGH" + >>> Priority.LOW < Priority.HIGH + True + >>> Priority.HIGH > Priority.MEDIUM + True + >>> max(Priority.LOW, Priority.HIGH) == Priority.HIGH + True + >>> # str ordering would give "MEDIUM" > "LOW" > "HIGH" (lexicographic), + >>> # but OrderedEnum uses definition order: LOW < MEDIUM < HIGH + >>> max(Priority.LOW, Priority.HIGH) == Priority.HIGH + True + >>> [p.value for p in sorted([Priority.HIGH, Priority.LOW, Priority.MEDIUM])] + ['LOW', 'MEDIUM', 'HIGH'] """ def __init__(self, value, *args, **kwds): super().__init__(*args, **kwds) self.__order = len(self.__class__) + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Inject comparison methods directly onto each subclass so they take + # MRO priority over data-type mixins like str or int. + # Without this, `class Foo(str, OrderedEnum)` would use str's + # lexicographic comparisons instead of definition-order comparisons. + for name in ("__gt__", "__ge__", "__lt__", "__le__"): + setattr(cls, name, getattr(OrderedEnum, name)) + def __ge__(self, other): if self.__class__ is other.__class__: return self.__order >= other.__order From e3134a129cc121c75bfc7ae2bc966185358a8aaf Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 25 Feb 2026 17:18:21 -0800 Subject: [PATCH 15/30] fix: correct misleading error log about NATS redelivery The NATS message is ACK'd at line 145, before update_state() and _update_job_progress(). If either of those raises, the except block was logging "NATS will redeliver" when it won't. Co-Authored-By: Claude --- ami/jobs/tasks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 4fae09a65..d3f26c1f6 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -172,7 +172,8 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub except Exception as e: job.logger.error( - f"Failed to process pipeline result for job {job_id}: {e}. NATS will redeliver the task message." + f"Failed to process pipeline result for job {job_id}: {e}. " + "NATS message was already acknowledged; it will not be redelivered." ) From 94e1bbb25b5eff7550696f7f7f5f66a45f792aaf Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Thu, 26 Feb 2026 09:10:13 -0800 Subject: [PATCH 16/30] Use job.logger --- ami/jobs/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 4c4547120..35b397b3e 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -150,7 +150,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) if not progress_info: - _cancel_job_on_missing_state(job_id, logger) + _cancel_job_on_missing_state(job_id, job.logger) return # update complete state based on latest progress info after saving results From dcf57fe205844ca53fbb9a8e7223961be7217c6d Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Thu, 26 Feb 2026 09:11:44 -0800 Subject: [PATCH 17/30] Use job.logger --- ami/jobs/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 2d4f9e8b9..5a3ba6d34 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -153,7 +153,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) if not progress_info: - logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + job.logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") # TODO: cancel the job to fail fast once PR #1144 is merged return From 4a25e549e54028773be76654284b22f26c3a44aa Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Thu, 26 Feb 2026 10:43:32 -0800 Subject: [PATCH 18/30] Integrate cancellation support --- ami/jobs/models.py | 26 +++++++++++++++----- ami/jobs/tasks.py | 38 ++++++++++++++++++------------ ami/ml/orchestration/jobs.py | 19 ++++++++------- ami/ml/orchestration/nats_queue.py | 31 +++++++++++++++--------- 4 files changed, 73 insertions(+), 41 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index be797dd4f..f18edfbc5 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -15,7 +15,7 @@ from ami.base.models import BaseModel from ami.base.schemas import ConfigurableStage, ConfigurableStageParam -from ami.jobs.tasks import run_job +from ami.jobs.tasks import cleanup_async_job_if_needed, run_job from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.ml.post_processing.registry import get_postprocessing_task @@ -331,7 +331,11 @@ def emit(self, record: logging.LogRecord): # Log to the current app logger logger.log(record.levelno, self.format(record)) - # Write to the logs field on the job instance + # Write to the logs field on the job instance. + # Refresh from DB first to reduce the window for concurrent overwrites — each + # worker holds its own stale in-memory copy of `logs`, so without a refresh the + # last writer always wins and earlier entries are silently dropped. + self.job.refresh_from_db(fields=["logs"]) timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") msg = f"[{timestamp}] {record.levelname} {self.format(record)}" if msg not in self.job.logs.stdout: @@ -350,7 +354,6 @@ def emit(self, record: logging.LogRecord): self.job.save(update_fields=["logs"], update_progress=False) except Exception as e: logger.error(f"Failed to save logs for job #{self.job.pk}: {e}") - pass @dataclass @@ -970,11 +973,18 @@ def cancel(self): """ self.status = JobState.CANCELING self.save() + + cleanup_async_job_if_needed(self) if self.task_id: task = run_job.AsyncResult(self.task_id) if task: task.revoke(terminate=True) self.save() + if self.dispatch_mode == JobDispatchMode.ASYNC_API: + # For async jobs we need to set the status to revoked here since the task already + # finished (it only queues the images). + self.status = JobState.REVOKED + self.save() else: self.status = JobState.REVOKED self.save() @@ -1084,11 +1094,15 @@ def get_default_progress(cls) -> JobProgress: def logger(self) -> logging.Logger: _logger = logging.getLogger(f"ami.jobs.{self.pk}") - # Only add JobLogHandler if not already present - if not any(isinstance(h, JobLogHandler) for h in _logger.handlers): - # Also log output to a field on thie model instance + # Update or add JobLogHandler, always pointing to the current instance. + # The logger is a process-level singleton so its handler may reference a stale + # job instance from a previous task execution in this worker process. + handler = next((h for h in _logger.handlers if isinstance(h, JobLogHandler)), None) + if handler is None: logger.info("Adding JobLogHandler to logger for job %s", self.pk) _logger.addHandler(JobLogHandler(self)) + else: + handler.job = self _logger.propagate = False return _logger diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 35b397b3e..ba3e923c2 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -88,7 +88,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub if not progress_info: # Acknowledge the task to prevent retries _ack_task_via_nats(reply_subject, logger) - _cancel_job_on_missing_state(job_id, logger) + _fail_job(job_id, "Redis state missing for job") return try: @@ -150,7 +150,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) if not progress_info: - _cancel_job_on_missing_state(job_id, job.logger) + _fail_job(job_id, "Redis state missing for job") return # update complete state based on latest progress info after saving results @@ -174,16 +174,24 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) -def _cancel_job_on_missing_state(job_id: int, logger: logging.Logger) -> None: +def _fail_job(job_id: int, reason: str) -> None: from ami.jobs.models import Job, JobState + from ami.ml.orchestration.jobs import cleanup_async_job_resources - logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") - - # cancel job (fail fast) since we don't know the state - job = Job.objects.get(pk=job_id) - if job.status != JobState.CANCELING and job.status not in JobState.final_states(): - job.logger.error(f"Job {job_id} is not canceling or finished, but Redis state is missing. ") - job.cancel() + try: + with transaction.atomic(): + job = Job.objects.select_for_update().get(pk=job_id) + if job.status in (JobState.CANCELING, *JobState.final_states()): + return + job.status = JobState.FAILURE + job.finished_at = datetime.datetime.now() + job.save(update_fields=["status", "finished_at"]) + + job.logger.error(f"Job {job_id} marked as FAILURE: {reason}") + cleanup_async_job_resources(job.pk, job.logger) + except Job.DoesNotExist: + logger.error(f"Cannot fail job {job_id}: not found") + cleanup_async_job_resources(job_id, logger) def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: @@ -289,10 +297,10 @@ def _update_job_progress( # Clean up async resources for completed jobs that use NATS/Redis if job.progress.is_complete(): job = Job.objects.get(pk=job_id) # Re-fetch outside transaction - _cleanup_job_if_needed(job) + cleanup_async_job_if_needed(job) -def _cleanup_job_if_needed(job) -> None: +def cleanup_async_job_if_needed(job) -> None: """ Clean up async resources (NATS/Redis) if this job uses them. @@ -308,7 +316,7 @@ def _cleanup_job_if_needed(job) -> None: # import here to avoid circular imports from ami.ml.orchestration.jobs import cleanup_async_job_resources - cleanup_async_job_resources(job) + cleanup_async_job_resources(job.pk, job.logger) @task_prerun.connect(sender=run_job) @@ -347,7 +355,7 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): # Clean up async resources for revoked jobs if state == JobState.REVOKED: - _cleanup_job_if_needed(job) + cleanup_async_job_if_needed(job) @task_failure.connect(sender=run_job, retry=False) @@ -362,7 +370,7 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.save() # Clean up async resources for failed jobs - _cleanup_job_if_needed(job) + cleanup_async_job_if_needed(job) def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index ce54ecd1c..95c763b1b 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -def cleanup_async_job_resources(job: "Job") -> bool: +def cleanup_async_job_resources(job_id: int, _logger: logging.Logger) -> bool: """ Clean up NATS JetStream and Redis resources for a completed job. @@ -22,7 +22,8 @@ def cleanup_async_job_resources(job: "Job") -> bool: Cleanup failures are logged but don't fail the job - data is already saved. Args: - job: The Job instance + job_id: The Job ID (integer primary key) + _logger: Logger to use for logging cleanup results Returns: bool: True if both cleanups succeeded, False otherwise """ @@ -31,26 +32,26 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup Redis state try: - state_manager = AsyncJobStateManager(job.pk) + state_manager = AsyncJobStateManager(job_id) state_manager.cleanup() - job.logger.info(f"Cleaned up Redis state for job {job.pk}") + _logger.info(f"Cleaned up Redis state for job {job_id}") redis_success = True except Exception as e: - job.logger.error(f"Error cleaning up Redis state for job {job.pk}: {e}") + _logger.error(f"Error cleaning up Redis state for job {job_id}: {e}") # Cleanup NATS resources async def cleanup(): async with TaskQueueManager() as manager: - return await manager.cleanup_job_resources(job.pk) + return await manager.cleanup_job_resources(job_id) try: nats_success = async_to_sync(cleanup)() if nats_success: - job.logger.info(f"Cleaned up NATS resources for job {job.pk}") + _logger.info(f"Cleaned up NATS resources for job {job_id}") else: - job.logger.warning(f"Failed to clean up NATS resources for job {job.pk}") + _logger.warning(f"Failed to clean up NATS resources for job {job_id}") except Exception as e: - job.logger.error(f"Error cleaning up NATS resources for job {job.pk}: {e}") + _logger.error(f"Error cleaning up NATS resources for job {job_id}: {e}") return redis_success and nats_success diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index a23d28ac8..f7bfc046b 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -95,21 +95,27 @@ def _get_consumer_name(self, job_id: int) -> str: """Get consumer name from job_id.""" return f"job-{job_id}-consumer" - async def _ensure_stream(self, job_id: int): - """Ensure stream exists for the given job.""" + async def _stream_exists(self, job_id: int) -> bool: + """Check if stream exists for the given job.""" if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") stream_name = self._get_stream_name(job_id) - subject = self._get_subject(job_id) - try: - await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT) - logger.debug(f"Stream {stream_name} already exists") - except asyncio.TimeoutError: - raise # NATS unreachable — let caller handle it rather than creating a stream blindly - except Exception as e: - logger.warning(f"Stream {stream_name} does not exist: {e}") + await self.js.stream_info(stream_name) + return True + except nats.js.errors.NotFoundError: + return False + + async def _ensure_stream(self, job_id: int): + """Ensure stream exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + if not await self._stream_exists(job_id): + stream_name = self._get_stream_name(job_id) + subject = self._get_subject(job_id) + logger.warning(f"Stream {stream_name} does not exist") # Stream doesn't exist, create it await asyncio.wait_for( self.js.add_stream( @@ -207,7 +213,10 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") try: - await self._ensure_stream(job_id) + if not await self._stream_exists(job_id): + logger.debug(f"Stream for job '{job_id}' does not exist when reserving task") + return [] + await self._ensure_consumer(job_id) consumer_name = self._get_consumer_name(job_id) From 5d38d67c053032a6fc3866d1958c8f3c7f2a4077 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Thu, 26 Feb 2026 10:59:37 -0800 Subject: [PATCH 19/30] merge, update tests --- ami/ml/orchestration/tests/test_nats_queue.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index a7bd91b68..cf3514bce 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -3,6 +3,8 @@ import unittest from unittest.mock import AsyncMock, MagicMock, patch +import nats + from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.schemas import PipelineProcessingTask @@ -51,8 +53,8 @@ async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() - js.stream_info.side_effect = Exception("Not found") - js.consumer_info.side_effect = Exception("Not found") + js.stream_info.side_effect = nats.js.errors.NotFoundError() + js.consumer_info.side_effect = nats.js.errors.NotFoundError() with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): async with TaskQueueManager() as manager: From ac90c2f064a3847e234766026230db72b1c9ef26 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Thu, 26 Feb 2026 11:36:15 -0800 Subject: [PATCH 20/30] Remove pause support in monkey --- ami/jobs/management/commands/chaos_monkey.py | 54 +------------------- 1 file changed, 2 insertions(+), 52 deletions(-) diff --git a/ami/jobs/management/commands/chaos_monkey.py b/ami/jobs/management/commands/chaos_monkey.py index f1c2793b5..04da595aa 100644 --- a/ami/jobs/management/commands/chaos_monkey.py +++ b/ami/jobs/management/commands/chaos_monkey.py @@ -11,23 +11,12 @@ # Flush all NATS JetStream streams (simulates broker state loss) python manage.py chaos_monkey flush nats - - # Pause Redis for 15 seconds then restore (simulates transient outage) - python manage.py chaos_monkey pause redis - - # Pause NATS for 30 seconds then restore - python manage.py chaos_monkey pause nats --duration 30 """ -import subprocess -import time - from asgiref.sync import async_to_sync from django.core.management.base import BaseCommand, CommandError from django_redis import get_redis_connection -REDIS_CONTAINER = "ami_local_redis" -NATS_CONTAINER = "ami_local_nats" NATS_URL = "nats://ami_local_nats:4222" @@ -37,35 +26,23 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( "action", - choices=["flush", "pause"], - help="flush: wipe all state. pause: stop the service temporarily then restore it.", + choices=["flush"], + help="flush: wipe all state.", ) parser.add_argument( "service", choices=["redis", "nats"], help="Target service to fault.", ) - parser.add_argument( - "--duration", - type=int, - default=15, - metavar="SECONDS", - help="How long to keep the service paused before restoring (pause only, default: 15).", - ) def handle(self, *args, **options): action = options["action"] service = options["service"] - duration = options["duration"] if action == "flush" and service == "redis": self._flush_redis() elif action == "flush" and service == "nats": self._flush_nats() - elif action == "pause" and service == "redis": - self._pause_container(REDIS_CONTAINER, duration) - elif action == "pause" and service == "nats": - self._pause_container(NATS_CONTAINER, duration) # ------------------------------------------------------------------ # Redis @@ -117,30 +94,3 @@ async def _delete_all_streams(): self.stdout.write(self.style.SUCCESS(f"Deleted {len(deleted)} stream(s).")) else: self.stdout.write("No streams found — NATS already empty.") - - # ------------------------------------------------------------------ - # Container pause/unpause (works for both redis and nats) - # ------------------------------------------------------------------ - - def _pause_container(self, container: str, duration: int): - self.stdout.write(f"Pausing container '{container}' for {duration}s...") - self._docker("pause", container) - self.stdout.write(self.style.WARNING(f"Container paused. Waiting {duration}s...")) - - for remaining in range(duration, 0, -1): - self.stdout.write(f"\r {remaining}s remaining...", ending="") - self.stdout.flush() - time.sleep(1) - - self.stdout.write("") # newline after countdown - self._docker("unpause", container) - self.stdout.write(self.style.SUCCESS(f"Container '{container}' restored.")) - - def _docker(self, subcommand: str, container: str): - result = subprocess.run( - ["docker", subcommand, container], - capture_output=True, - text=True, - ) - if result.returncode != 0: - raise CommandError(f"`docker {subcommand} {container}` failed:\n{result.stderr.strip()}") From 4eb763a6b3725da8f25a76c0614ae5e53c8d9d82 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 26 Feb 2026 16:38:02 -0800 Subject: [PATCH 21/30] fix: cancel async jobs by cleaning up NATS/Redis and stopping task delivery For async_api jobs, the Celery task completes after queuing images to NATS, so task.revoke() has no effect. The worker kept pulling tasks via the /tasks endpoint because it only checked final_states(), not CANCELING. - Add JobState.active_states() (STARTED, RETRY) for positive task-serving check - /tasks endpoint returns empty unless job is in active_states() - Job.cancel() for async_api jobs: clean up NATS/Redis, then set REVOKED Co-Authored-By: Claude --- ami/jobs/models.py | 9 ++++++++- ami/jobs/views.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index f18edfbc5..fd3a9168e 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -88,6 +88,11 @@ def final_states(cls): def failed_states(cls): return [cls.FAILURE, cls.REVOKED, cls.UNKNOWN] + @classmethod + def active_states(cls): + """States where a job is actively processing and should serve tasks to workers.""" + return [cls.STARTED, cls.RETRY] + def get_status_label(status: JobState, progress: float) -> str: """ @@ -969,7 +974,9 @@ def retry(self, async_task=True): def cancel(self): """ - Terminate the celery task. + Cancel a job. For async_api jobs, clean up NATS/Redis resources + and transition through CANCELING → REVOKED. For other jobs, + revoke the Celery task. """ self.status = JobState.CANCELING self.save() diff --git a/ami/jobs/views.py b/ami/jobs/views.py index cdd1886e7..6d0626f23 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -237,8 +237,8 @@ def tasks(self, request, pk=None): if job.dispatch_mode != JobDispatchMode.ASYNC_API: raise ValidationError("Only async_api jobs have fetchable tasks") - # Don't fetch tasks from completed/failed/revoked jobs - if job.status in JobState.final_states() or job.status == JobState.CANCELING: + # Only serve tasks for actively processing jobs + if job.status not in JobState.active_states(): return Response({"tasks": []}) # Validate that the job has a pipeline From 867121433289be78acf081ecf761b5a408986a78 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 26 Feb 2026 16:38:09 -0800 Subject: [PATCH 22/30] fix(ui): hide Retry button while job is in CANCELING state canRetry now excludes CANCELING so the Retry button stays hidden during the drain period, matching the backend's transitional state. Co-Authored-By: Claude --- ui/src/data-services/models/job.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ui/src/data-services/models/job.ts b/ui/src/data-services/models/job.ts index 625db4ec7..41b1c8ffd 100644 --- a/ui/src/data-services/models/job.ts +++ b/ui/src/data-services/models/job.ts @@ -66,7 +66,8 @@ export class Job { this._job.user_permissions.includes(UserPermission.Run) && this.status.code !== 'CREATED' && this.status.code !== 'STARTED' && - this.status.code !== 'PENDING' + this.status.code !== 'PENDING' && + this.status.code !== 'CANCELING' ) } From b1146cc7ffe2697e76fde4c856cd5ad955207739 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 26 Feb 2026 16:38:15 -0800 Subject: [PATCH 23/30] fix: downgrade Redis-missing log to warning for canceled jobs When a job is canceled, NATS/Redis cleanup runs before in-flight results finish processing. The resulting "Redis state missing" message is expected, not an error. Co-Authored-By: Claude --- ami/jobs/tasks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 1725ce410..adfce9c85 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -86,9 +86,9 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids) if not progress_info: - # Acknowledge the task to prevent retries + logger.warning(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") + # Acknowledge the task to prevent retries, since we don't know the state _ack_task_via_nats(reply_subject, logger) - _fail_job(job_id, "Redis state missing for job") return try: @@ -152,7 +152,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) if not progress_info: - _fail_job(job_id, "Redis state missing for job") + job.logger.warning(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.") return # update complete state based on latest progress info after saving results From dccaceb0b261bd13937d29f2f52866848cd51931 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 26 Feb 2026 16:38:46 -0800 Subject: [PATCH 24/30] docs: add async job monitoring reference Covers all monitoring points for NATS async jobs: Django ORM, REST API, tasks endpoint, NATS consumer state, Redis counters, Docker logs, and AMI worker logs. Linked from CLAUDE.md and the test_ml_job_e2e command. Co-Authored-By: Claude --- .agents/AGENTS.md | 10 + .../management/commands/test_ml_job_e2e.py | 6 +- .../claude/reference/monitoring-async-jobs.md | 205 ++++++++++++++++++ 3 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 docs/claude/reference/monitoring-async-jobs.md diff --git a/.agents/AGENTS.md b/.agents/AGENTS.md index 8acceeb4d..1b6ac558e 100644 --- a/.agents/AGENTS.md +++ b/.agents/AGENTS.md @@ -650,6 +650,16 @@ images = SourceImage.objects.annotate(det_count=Count('detections')) - Use `@shared_task` decorator for all tasks - Check Flower UI for debugging: http://localhost:5555 +### E2E Testing & Monitoring Async Jobs + +Run an end-to-end ML job test: +```bash +docker compose run --rm django python manage.py test_ml_job_e2e \ + --project 18 --dispatch-mode async_api --collection 142 --pipeline "global_moths_2024" +``` + +For monitoring running jobs (Django ORM, REST API, NATS consumer state, Redis counters, worker logs, etc.), see `docs/claude/reference/monitoring-async-jobs.md`. + ### Running a Single Test ```bash diff --git a/ami/jobs/management/commands/test_ml_job_e2e.py b/ami/jobs/management/commands/test_ml_job_e2e.py index 2f613e39c..f79c54fbb 100644 --- a/ami/jobs/management/commands/test_ml_job_e2e.py +++ b/ami/jobs/management/commands/test_ml_job_e2e.py @@ -10,7 +10,11 @@ class Command(BaseCommand): - help = "Run end-to-end test of ML job processing" + help = ( + "Run end-to-end test of ML job processing.\n\n" + "For monitoring and debugging running jobs, see:\n" + " docs/claude/reference/monitoring-async-jobs.md" + ) def add_arguments(self, parser): parser.add_argument("--project", type=int, required=True, help="Project ID") diff --git a/docs/claude/reference/monitoring-async-jobs.md b/docs/claude/reference/monitoring-async-jobs.md new file mode 100644 index 000000000..339e2d100 --- /dev/null +++ b/docs/claude/reference/monitoring-async-jobs.md @@ -0,0 +1,205 @@ +# Monitoring Async (NATS) Jobs + +Reference for monitoring and debugging async_api jobs that use NATS JetStream for task distribution to external workers (e.g., AMI Data Companion). + +## Starting a Test Job + +```bash +docker compose run --rm django python manage.py test_ml_job_e2e \ + --project 18 \ + --dispatch-mode async_api \ + --collection 142 \ + --pipeline "global_moths_2024" +``` + +Or create a job via the UI at http://localhost:4000/projects/18/jobs. + +## Monitoring Points + +### 1. Web UI + +**Job details page:** `http://localhost:4000/projects/{PROJECT_ID}/jobs/{JOB_ID}` + +Shows status bar, progress percentage, stage breakdown, and logs. Polls the API automatically. + +### 2. Jobs REST API + +```bash +# Get auth token +TOKEN=$(docker compose exec django python manage.py shell -c \ + "from rest_framework.authtoken.models import Token; print(Token.objects.first().key)" 2>/dev/null) + +# Job status & progress summary +curl -s http://localhost:8000/api/v2/jobs/{JOB_ID}/ \ + -H "Authorization: Token $TOKEN" | jq '{id, status, dispatch_mode, progress: .progress.summary}' + +# Full stage breakdown +curl -s http://localhost:8000/api/v2/jobs/{JOB_ID}/ \ + -H "Authorization: Token $TOKEN" | jq '.progress.stages[] | {key: .key, status: .status, progress: .progress}' +``` + +### 3. Tasks Endpoint (Worker-Facing) + +This is what the external worker polls to get batches of images to process. + +```bash +# See what the worker would get (fetches from NATS, reserves tasks) +curl -s "http://localhost:8000/api/v2/jobs/{JOB_ID}/tasks/?batch=8" \ + -H "Authorization: Token $TOKEN" | jq '.tasks | length' + +# Returns empty [] when job is not in active_states (STARTED, RETRY) +# i.e. returns empty for CANCELING, REVOKED, SUCCESS, FAILURE, etc. +``` + +### 4. Django ORM (Shell) + +```bash +docker compose exec django python manage.py shell -c " +from ami.jobs.models import Job +j = Job.objects.get(pk={JOB_ID}) +print(f'Status: {j.status}') +print(f'Dispatch mode: {j.dispatch_mode}') +print(f'Progress: {j.progress.summary.progress*100:.1f}%') +print(f'Started: {j.started_at}') +print(f'Finished: {j.finished_at}') +for s in j.progress.stages: + print(f' {s.key}: {s.status} {s.progress*100:.1f}%') +" +``` + +### 5. NATS JetStream Consumer State + +Shows the queue depth, in-flight tasks, and acknowledgment progress. + +```bash +docker compose exec django python manage.py shell -c " +from ami.ml.orchestration.nats_queue import TaskQueueManager +import asyncio +async def check(): + async with TaskQueueManager() as m: + info = await m.js.consumer_info('job_{JOB_ID}', 'job-{JOB_ID}-consumer') + print(f'num_pending: {info.num_pending}') # Tasks waiting in queue + print(f'num_ack_pending: {info.num_ack_pending}') # Tasks reserved but not yet ACKed + print(f'num_redelivered: {info.num_redelivered}') # Tasks redelivered after timeout + print(f'delivered.seq: {info.delivered.stream_seq}') # Last delivered sequence + print(f'ack_floor.seq: {info.ack_floor.stream_seq}') # Last contiguous ACK +asyncio.run(check()) +" +``` + +Key fields: +- `num_pending` = tasks still in queue, not yet reserved by any worker +- `num_ack_pending` = tasks reserved by worker, waiting for result POST + ACK +- `num_redelivered` = tasks that timed out (TTR=30s default) and were redelivered +- When `num_pending=0` and `num_ack_pending=0`, all tasks have been processed + +### 6. Redis State (Atomic Progress Counters) + +Tracks per-stage progress independently of the Job model. Updated atomically by Celery result tasks. + +```bash +docker compose exec django python manage.py shell -c " +from ami.ml.orchestration.async_job_state import AsyncJobStateManager +sm = AsyncJobStateManager({JOB_ID}) +for stage in sm.STAGES: + prog = sm.get_progress(stage) + print(f'{stage}: remaining={prog.remaining} processed={prog.processed}/{prog.total} failed={prog.failed} ({prog.percentage*100:.1f}%)') +" +``` + +### 7. Django Logs (Docker Compose) + +```bash +# All django logs (includes task reservations and result processing) +docker compose logs -f django + +# Filter for specific job +docker compose logs -f django 2>&1 | grep "1408" + +# Filter for task reservations +docker compose logs -f django 2>&1 | grep "Reserved" + +# Filter for result processing +docker compose logs -f django 2>&1 | grep "Queued pipeline result" +``` + +### 8. Celery Worker Logs + +```bash +# Celery worker logs (result saving, NATS ACKs, progress updates) +docker compose logs -f celeryworker + +# Filter for specific job +docker compose logs -f celeryworker 2>&1 | grep "job 1408" +``` + +### 9. AMI Worker Logs (External) + +The AMI Data Companion worker runs outside Docker. Check its terminal output for: +- Batch processing progress (e.g., "Finished batch 84. Total items: 672") +- Model inference times (detection + classification) +- Connection errors to Django API or NATS + +```bash +# If running via conda +conda activate ami-py311 +ami worker --pipeline global_moths_2024 + +# Worker registration (loads ML models, ~20s) +ami worker register "local-worker" --project 18 +``` + +## Continuous Monitoring (Watch Loop) + +```bash +# Poll job status every 5 seconds +watch -n 5 'docker compose exec django python manage.py shell -c " +from ami.jobs.models import Job +j = Job.objects.get(pk={JOB_ID}) +print(f\"Status: {j.status} | Progress: {j.progress.summary.progress*100:.1f}%\") +for s in j.progress.stages: + print(f\" {s.key}: {s.status} {s.progress*100:.1f}%\") +"' +``` + +## Job Lifecycle (async_api) + +``` +CREATED → PENDING → STARTED → [processing] → SUCCESS + ↓ + CANCELING → REVOKED (user cancels) + ↓ + FAILURE (error during processing) +``` + +1. **STARTED**: Celery task collects images, publishes to NATS stream, then returns +2. **Processing**: Worker polls `/tasks`, processes batches, POSTs to `/result/` +3. **SUCCESS**: All results received, progress reaches 100% +4. **CANCELING → REVOKED**: User cancels, NATS stream/consumer deleted, status set to REVOKED. In-flight results may still trickle in and are saved. + +## Key Configuration + +| Setting | Default | Source | +|---------|---------|--------| +| NATS task TTR (visibility timeout) | 30s | `NATS_TASK_TTR` env var | +| NATS max_ack_pending | 1000 | `NATS_MAX_ACK_PENDING` env var | +| NATS max_deliver (retries) | 5 | hardcoded in `nats_queue.py` | +| NATS stream retention | 24h | hardcoded in `nats_queue.py` | +| Worker batch size | varies | worker's `?batch=N` param | + +## Troubleshooting + +**Job stuck in STARTED with no progress:** +- Check if worker is running and connected +- Check NATS consumer state — if `num_pending > 0` but nothing is being delivered, worker may have lost connection +- Check `num_redelivered` — high count means tasks are timing out (worker too slow or crashing) + +**Job stuck in CANCELING:** +- Pre-fix: job was stuck because `/tasks` kept serving tasks and nothing transitioned to REVOKED +- Post-fix: `cancel()` cleans up NATS resources and sets REVOKED synchronously +- If still stuck, the periodic `check_incomplete_jobs` beat task (PR #1025) will catch it + +**Results not being saved:** +- Check celeryworker logs for errors in `process_nats_pipeline_result` +- Check Redis state — if `process` is ahead of `results`, Celery is backed up saving results +- Check NATS `num_ack_pending` — high count means results haven't been ACKed yet From d63be48412261e0a43ef9284fb84f8823ea59a54 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 26 Feb 2026 18:16:49 -0800 Subject: [PATCH 25/30] fix: update tests for active_states() guard on /tasks endpoint Tests need to set job status to STARTED since the /tasks endpoint now only serves tasks for jobs in active_states() (STARTED, RETRY). Co-Authored-By: Claude --- ami/jobs/tests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 033a08b5c..65bf1e6f1 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -445,7 +445,8 @@ def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() job = self._create_ml_job("Job for batch test", pipeline) job.dispatch_mode = JobDispatchMode.ASYNC_API - job.save(update_fields=["dispatch_mode"]) + job.status = JobState.STARTED + job.save(update_fields=["dispatch_mode", "status"]) images = [ SourceImage.objects.create( path=f"image_{i}.jpg", @@ -487,6 +488,7 @@ def test_tasks_endpoint_without_pipeline(self): name="Job without pipeline", source_image_collection=self.source_image_collection, dispatch_mode=JobDispatchMode.ASYNC_API, + status=JobState.STARTED, ) self.client.force_authenticate(user=self.user) From f4d88ff3f4cdae67926b48a27057c0d7d2818804 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 27 Feb 2026 15:49:41 -0800 Subject: [PATCH 26/30] fix: improve job cancel ordering, fail status sync, and log handler safety - Reorder cancel(): revoke Celery task before cleaning up async resources to prevent a theoretical race where a worker recreates state after cleanup - Remove redundant self.save() after task.revoke() (no fields changed) - Use update_status() in _fail_job() to keep progress.summary.status in sync with job.status - Wrap entire log handler emit() DB sequence (refresh_from_db + mutations + save) in try/except so a DB failure during logging cannot crash callers Co-Authored-By: Claude --- ami/jobs/models.py | 30 +++++++++++++++--------------- ami/jobs/tasks.py | 4 ++-- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index fd3a9168e..89be29312 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -340,22 +340,22 @@ def emit(self, record: logging.LogRecord): # Refresh from DB first to reduce the window for concurrent overwrites — each # worker holds its own stale in-memory copy of `logs`, so without a refresh the # last writer always wins and earlier entries are silently dropped. - self.job.refresh_from_db(fields=["logs"]) - timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - msg = f"[{timestamp}] {record.levelname} {self.format(record)}" - if msg not in self.job.logs.stdout: - self.job.logs.stdout.insert(0, msg) + # @TODO consider saving logs to the database periodically rather than on every log + try: + self.job.refresh_from_db(fields=["logs"]) + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + msg = f"[{timestamp}] {record.levelname} {self.format(record)}" + if msg not in self.job.logs.stdout: + self.job.logs.stdout.insert(0, msg) - # Write a simpler copy of any errors to the errors field - if record.levelno >= logging.ERROR: - if record.message not in self.job.logs.stderr: - self.job.logs.stderr.insert(0, record.message) + # Write a simpler copy of any errors to the errors field + if record.levelno >= logging.ERROR: + if record.message not in self.job.logs.stderr: + self.job.logs.stderr.insert(0, record.message) - if len(self.job.logs.stdout) > self.max_log_length: - self.job.logs.stdout = self.job.logs.stdout[: self.max_log_length] + if len(self.job.logs.stdout) > self.max_log_length: + self.job.logs.stdout = self.job.logs.stdout[: self.max_log_length] - # @TODO consider saving logs to the database periodically rather than on every log - try: self.job.save(update_fields=["logs"], update_progress=False) except Exception as e: logger.error(f"Failed to save logs for job #{self.job.pk}: {e}") @@ -981,12 +981,10 @@ def cancel(self): self.status = JobState.CANCELING self.save() - cleanup_async_job_if_needed(self) if self.task_id: task = run_job.AsyncResult(self.task_id) if task: task.revoke(terminate=True) - self.save() if self.dispatch_mode == JobDispatchMode.ASYNC_API: # For async jobs we need to set the status to revoked here since the task already # finished (it only queues the images). @@ -996,6 +994,8 @@ def cancel(self): self.status = JobState.REVOKED self.save() + cleanup_async_job_if_needed(self) + def update_status(self, status=None, save=True): """ Update the status of the job based on the status of the celery task. diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 7c201a529..917608be0 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -187,9 +187,9 @@ def _fail_job(job_id: int, reason: str) -> None: job = Job.objects.select_for_update().get(pk=job_id) if job.status in (JobState.CANCELING, *JobState.final_states()): return - job.status = JobState.FAILURE + job.update_status(JobState.FAILURE, save=False) job.finished_at = datetime.datetime.now() - job.save(update_fields=["status", "finished_at"]) + job.save(update_fields=["status", "progress", "finished_at"]) job.logger.error(f"Job {job_id} marked as FAILURE: {reason}") cleanup_async_job_resources(job.pk, job.logger) From 20e4ec2234c01ed96012cfdfdb02727c03e1f09e Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 27 Feb 2026 15:49:53 -0800 Subject: [PATCH 27/30] fix: restore timeout on _stream_exists and use settings for NATS_URL - Add asyncio.wait_for() wrapper to _stream_exists() stream_info call, accidentally dropped during refactor from _ensure_stream - Read NATS_URL from Django settings in chaos_monkey command instead of hardcoding, consistent with TaskQueueManager Co-Authored-By: Claude --- ami/jobs/management/commands/chaos_monkey.py | 3 ++- ami/ml/orchestration/nats_queue.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ami/jobs/management/commands/chaos_monkey.py b/ami/jobs/management/commands/chaos_monkey.py index 04da595aa..50ad3c6ab 100644 --- a/ami/jobs/management/commands/chaos_monkey.py +++ b/ami/jobs/management/commands/chaos_monkey.py @@ -14,10 +14,11 @@ """ from asgiref.sync import async_to_sync +from django.conf import settings from django.core.management.base import BaseCommand, CommandError from django_redis import get_redis_connection -NATS_URL = "nats://ami_local_nats:4222" +NATS_URL = getattr(settings, "NATS_URL", "nats://nats:4222") class Command(BaseCommand): diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index f7bfc046b..9b93104a5 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -102,7 +102,7 @@ async def _stream_exists(self, job_id: int) -> bool: stream_name = self._get_stream_name(job_id) try: - await self.js.stream_info(stream_name) + await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT) return True except nats.js.errors.NotFoundError: return False From cf18987273abebf40eff5e62cda3ed6e7a8b1637 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 27 Feb 2026 15:50:02 -0800 Subject: [PATCH 28/30] fix(ui): block retry button while job is in RETRY state RETRY is an active processing state; allowing another retry while one is already running could cause duplicate execution. Co-Authored-By: Claude --- ui/src/data-services/models/job.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ui/src/data-services/models/job.ts b/ui/src/data-services/models/job.ts index 41b1c8ffd..66302d43e 100644 --- a/ui/src/data-services/models/job.ts +++ b/ui/src/data-services/models/job.ts @@ -67,7 +67,8 @@ export class Job { this.status.code !== 'CREATED' && this.status.code !== 'STARTED' && this.status.code !== 'PENDING' && - this.status.code !== 'CANCELING' + this.status.code !== 'CANCELING' && + this.status.code !== 'RETRY' ) } From f1bed5e308522b27f5a3342659ca677aacb4c4c3 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 27 Feb 2026 16:37:57 -0800 Subject: [PATCH 29/30] docs: clarify _stream_exists timeout propagation design MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add docstring explaining that TimeoutError is deliberately not caught — an unreachable NATS server should be a hard failure, not a "stream missing" false negative. Multiple reviewers questioned this behavior. Co-Authored-By: Claude --- ami/ml/orchestration/nats_queue.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 9b93104a5..884676637 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -96,7 +96,12 @@ def _get_consumer_name(self, job_id: int) -> str: return f"job-{job_id}-consumer" async def _stream_exists(self, job_id: int) -> bool: - """Check if stream exists for the given job.""" + """Check if stream exists for the given job. + + Only catches NotFoundError (→ False). TimeoutError propagates deliberately + so callers treat an unreachable NATS server as a hard failure rather than + a missing stream. + """ if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") From a16fc05f2de788adca69bc545bdc177fa3329827 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 27 Feb 2026 16:38:07 -0800 Subject: [PATCH 30/30] docs: add language tag to fenced code block in monitoring guide Co-Authored-By: Claude --- docs/claude/reference/monitoring-async-jobs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/claude/reference/monitoring-async-jobs.md b/docs/claude/reference/monitoring-async-jobs.md index 339e2d100..946ddadaa 100644 --- a/docs/claude/reference/monitoring-async-jobs.md +++ b/docs/claude/reference/monitoring-async-jobs.md @@ -164,7 +164,7 @@ for s in j.progress.stages: ## Job Lifecycle (async_api) -``` +```text CREATED → PENDING → STARTED → [processing] → SUCCESS ↓ CANCELING → REVOKED (user cancels)