Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from django.db import migrations


def create_periodic_tasks(apps, schema_editor):
from django_celery_beat.models import CrontabSchedule, PeriodicTask

Comment thread
coderabbitai[bot] marked this conversation as resolved.
stale_schedule, _ = CrontabSchedule.objects.get_or_create(
minute="*/15",
hour="*",
day_of_week="*",
day_of_month="*",
month_of_year="*",
)
PeriodicTask.objects.get_or_create(
name="jobs.check_stale_jobs",
defaults={
"task": "ami.jobs.tasks.check_stale_jobs_task",
"crontab": stale_schedule,
"description": "Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS",
},
)

stats_schedule, _ = CrontabSchedule.objects.get_or_create(
minute="*/5",
hour="*",
day_of_week="*",
day_of_month="*",
month_of_year="*",
)
PeriodicTask.objects.get_or_create(
name="jobs.log_running_async_job_stats",
defaults={
"task": "ami.jobs.tasks.log_running_async_job_stats",
"crontab": stats_schedule,
"description": "Log NATS consumer delivered/ack/pending stats for each running async_api job",
},
)


def delete_periodic_tasks(apps, schema_editor):
from django_celery_beat.models import PeriodicTask

PeriodicTask.objects.filter(
name__in=["jobs.check_stale_jobs", "jobs.log_running_async_job_stats"],
).delete()


class Migration(migrations.Migration):
dependencies = [
("jobs", "0019_job_dispatch_mode"),
]

operations = [
migrations.RunPython(create_periodic_tasks, delete_periodic_tasks),
]
63 changes: 63 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,69 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di
return results


# Beat schedule is every 15 minutes for check_stale_jobs; expire queued copies
# that accumulate while a worker is unavailable so we don't process a backlog.
_STALE_JOB_BEAT_EXPIRES = 60 * 10


@celery_app.task(soft_time_limit=300, time_limit=360, expires=_STALE_JOB_BEAT_EXPIRES)
def check_stale_jobs_task() -> dict:
Comment thread
mihow marked this conversation as resolved.
Outdated
"""Celery Beat entry point for `check_stale_jobs`.

Runs the existing stale-job reconciler on a schedule so jobs don't silently
sit in a running state for days when their Celery task is gone or the
worker crashed. Returns a summary dict for flower / task-result visibility.
"""
results = check_stale_jobs()
updated = sum(1 for r in results if r["action"] == "updated")
revoked = sum(1 for r in results if r["action"] == "revoked")
logger.info(
"check_stale_jobs_task finished: %d stale job(s), %d updated from Celery, %d revoked",
len(results),
updated,
revoked,
)
return {"total": len(results), "updated": updated, "revoked": revoked}


# Expire faster than the stale-job task — this is observability, a skipped
# cycle is fine and we'd rather not pile up backlog of snapshot work.
_ASYNC_STATS_BEAT_EXPIRES = 60 * 4


@celery_app.task(soft_time_limit=180, time_limit=240, expires=_ASYNC_STATS_BEAT_EXPIRES)
def log_running_async_job_stats() -> dict:
"""Log a NATS consumer snapshot (delivered/ack/pending/redelivered) per running async_api job.

Writes to the per-job logger so operators see counts in the job's UI log
without waiting for it to finish. Read-only: no status changes.
"""
from ami.jobs.models import Job, JobDispatchMode, JobState

# Resolve each job's per-job logger synchronously — the property touches Django
# ORM via its JobLogHandler, which is only safe outside the event loop.
running_jobs = list(
Job.objects.filter(
status__in=JobState.running_states(),
dispatch_mode=JobDispatchMode.ASYNC_API,
)
)
if not running_jobs:
return {"checked": 0}

async def _snapshot_all():
for job in running_jobs:
try:
async with TaskQueueManager(job_logger=job.logger) as manager:
await manager.log_consumer_stats_snapshot(job.pk)
except Exception:
# One job's NATS failure must not block snapshots for others.
logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk)
Comment thread
mihow marked this conversation as resolved.
Outdated

async_to_sync(_snapshot_all)()
return {"checked": len(running_jobs)}


def cleanup_async_job_if_needed(job) -> None:
"""
Clean up async resources (NATS/Redis) if this job uses them.
Expand Down
92 changes: 92 additions & 0 deletions ami/jobs/tests/test_periodic_beat_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from datetime import timedelta
from unittest.mock import AsyncMock, patch

from django.test import TestCase
from django.utils import timezone

from ami.jobs.models import Job, JobDispatchMode, JobState
from ami.jobs.tasks import check_stale_jobs_task, log_running_async_job_stats
from ami.main.models import Project


class CheckStaleJobsTaskTest(TestCase):
def setUp(self):
self.project = Project.objects.create(name="Beat schedule test project")

def _create_stale_job(self, status=JobState.STARTED, hours_ago=100):
job = Job.objects.create(project=self.project, name="stale", status=status)
Job.objects.filter(pk=job.pk).update(updated_at=timezone.now() - timedelta(hours=hours_ago))
job.refresh_from_db()
return job

@patch("ami.jobs.tasks.cleanup_async_job_if_needed")
def test_returns_summary_counts(self, _mock_cleanup):
self._create_stale_job()
self._create_stale_job()
result = check_stale_jobs_task()
self.assertEqual(result, {"total": 2, "updated": 0, "revoked": 2})

def test_no_stale_jobs_returns_zero_summary(self):
self._create_stale_job(hours_ago=1) # recent — not stale
self.assertEqual(check_stale_jobs_task(), {"total": 0, "updated": 0, "revoked": 0})


class LogRunningAsyncJobStatsTest(TestCase):
def setUp(self):
self.project = Project.objects.create(name="Async snapshot test project")

def _create_async_job(self, status=JobState.STARTED):
job = Job.objects.create(project=self.project, name=f"async {status}", status=status)
Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API)
job.refresh_from_db()
return job

def test_no_running_jobs_short_circuits(self):
# A celery job with async dispatch but a final status should be skipped.
self._create_async_job(status=JobState.SUCCESS)
self.assertEqual(log_running_async_job_stats(), {"checked": 0})

@patch("ami.jobs.tasks.TaskQueueManager")
def test_snapshots_each_running_async_job(self, mock_manager_cls):
job_a = self._create_async_job()
job_b = self._create_async_job()

instance = mock_manager_cls.return_value
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.log_consumer_stats_snapshot = AsyncMock()

result = log_running_async_job_stats()

self.assertEqual(result, {"checked": 2})
snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list]
self.assertCountEqual(snapshots, [job_a.pk, job_b.pk])

@patch("ami.jobs.tasks.TaskQueueManager")
def test_one_job_failure_does_not_block_others(self, mock_manager_cls):
job_ok = self._create_async_job()
job_broken = self._create_async_job()

instance = mock_manager_cls.return_value
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)

calls = []

async def _snapshot(job_id):
calls.append(job_id)
if job_id == job_broken.pk:
raise RuntimeError("nats down for this one")

instance.log_consumer_stats_snapshot = AsyncMock(side_effect=_snapshot)

result = log_running_async_job_stats()
self.assertEqual(result, {"checked": 2})
self.assertIn(job_ok.pk, calls)
self.assertIn(job_broken.pk, calls)

def test_non_async_jobs_skipped(self):
job = Job.objects.create(project=self.project, name="sync job", status=JobState.STARTED)
# default dispatch_mode should not be ASYNC_API
self.assertNotEqual(job.dispatch_mode, JobDispatchMode.ASYNC_API)
self.assertEqual(log_running_async_job_stats(), {"checked": 0})
24 changes: 18 additions & 6 deletions ami/ml/orchestration/nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,18 @@ async def _log_final_consumer_stats(self, job_id: int) -> None:
redelivered before the consumer vanished. Failures here must NOT block
cleanup — if the consumer or stream is already gone, just skip it.
"""
await self._log_consumer_stats(job_id, prefix="Finalizing NATS consumer", suffix="before deletion")

async def log_consumer_stats_snapshot(self, job_id: int) -> None:
"""Log a mid-flight snapshot of the consumer state for a running job.

Used by the periodic `log_running_async_job_stats` beat task so operators
can see deliver/ack/pending counts without waiting for the job to finish.
Tolerant of missing stream/consumer like the cleanup-time variant.
"""
await self._log_consumer_stats(job_id, prefix="NATS consumer status")

async def _log_consumer_stats(self, job_id: int, *, prefix: str, suffix: str = "") -> None:
if self.js is None:
return
stream_name = self._get_stream_name(job_id)
Expand All @@ -487,15 +499,15 @@ async def _log_final_consumer_stats(self, job_id: int) -> None:
timeout=NATS_JETSTREAM_TIMEOUT,
)
except Exception as e:
# Broad catch is intentional here (unlike _ensure_consumer): at
# cleanup time we tolerate any failure — stream gone, consumer
# already deleted, auth, timeout — so the delete calls below
# still get a chance to run.
logger.debug(f"Could not fetch consumer info for {consumer_name} before deletion: {e}")
# Broad catch is intentional: if the consumer or stream is gone we
# just skip — callers (cleanup, periodic snapshot) should never fail
# because we couldn't read stats.
logger.debug(f"Could not fetch consumer info for {consumer_name}: {e}")
return
tail = f" {suffix}" if suffix else ""
await self.log_async(
logging.INFO,
f"Finalizing NATS consumer {consumer_name} before deletion ({self._format_consumer_stats(info)})",
f"{prefix} {consumer_name}{tail} ({self._format_consumer_stats(info)})",
)

async def delete_consumer(self, job_id: int) -> bool:
Expand Down
38 changes: 38 additions & 0 deletions ami/ml/orchestration/tests/test_nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,44 @@ async def test_publish_failure_surfaces_on_job_logger(self):
f"expected publish failure on job_logger, got {messages}",
)

async def test_log_consumer_stats_snapshot_writes_current_stats(self):
"""The periodic snapshot helper logs delivered/ack/pending WITHOUT
deleting the consumer — it's a mid-flight observability hook."""
nc, js = self._create_mock_nats_connection()
js.consumer_info.return_value = self._make_consumer_info(
delivered=50, ack_floor=40, num_pending=10, num_ack_pending=10, num_redelivered=2
)

job_logger = self._make_captured_logger()
captured = job_logger._captured # type: ignore[attr-defined]

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager(job_logger=job_logger) as manager:
await manager.log_consumer_stats_snapshot(9)

messages = [m for _, m in captured]
self.assertTrue(
any("NATS consumer status job-9-consumer" in m for m in messages),
f"expected snapshot line on job_logger, got {messages}",
)
snapshot_line = next(m for m in messages if "NATS consumer status" in m)
for expected in ("delivered=50", "ack_floor=40", "num_redelivered=2"):
self.assertIn(expected, snapshot_line)
# Must NOT have triggered a delete — this is read-only observability.
js.delete_consumer.assert_not_called()
js.delete_stream.assert_not_called()

async def test_log_consumer_stats_snapshot_tolerates_missing_consumer(self):
"""If the consumer is already gone, the snapshot helper just no-ops."""
nc, js = self._create_mock_nats_connection()
js.consumer_info.side_effect = nats.js.errors.NotFoundError()

job_logger = self._make_captured_logger()

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager(job_logger=job_logger) as manager:
await manager.log_consumer_stats_snapshot(99) # must not raise

async def test_no_job_logger_falls_back_to_module_logger_only(self):
"""When job_logger is None (e.g., module-level uses like advisory
listener), lifecycle logs must still be emitted to the module logger
Expand Down
Loading