Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions ami/ml/orchestration/jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from asgiref.sync import async_to_sync
from asgiref.sync import async_to_sync, sync_to_async

from ami.jobs.models import Job, JobState
from ami.main.models import SourceImage
Expand All @@ -22,8 +22,13 @@ def cleanup_async_job_resources(job_id: int, _logger: logging.Logger) -> bool:
Cleanup failures are logged but don't fail the job - data is already saved.

Args:
job_id: The Job ID (integer primary key)
_logger: Logger to use for logging cleanup results
job_id: The Job ID (integer primary key). For ASYNC_API jobs this should
be called with the per-job logger (``job.logger``) so the UI log shows
cleanup events and the forensic consumer-stats snapshot that
``TaskQueueManager.cleanup_job_resources`` emits before deletion.
_logger: Logger to use for logging cleanup results. Passed through to
TaskQueueManager so lifecycle events land on both the module logger
and the per-job logger.
Returns:
bool: True if both cleanups succeeded, False otherwise
"""
Expand All @@ -39,9 +44,10 @@ def cleanup_async_job_resources(job_id: int, _logger: logging.Logger) -> bool:
except Exception as e:
_logger.error(f"Error cleaning up Redis state for job {job_id}: {e}")

# Cleanup NATS resources
# Cleanup NATS resources. Pass _logger through so TaskQueueManager can
# log final consumer stats and deletion events against the per-job logger.
async def cleanup():
async with TaskQueueManager() as manager:
async with TaskQueueManager(job_logger=_logger) as manager:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
return await manager.cleanup_job_resources(job_id)

try:
Expand Down Expand Up @@ -97,16 +103,26 @@ async def queue_all_images():
successful_queues = 0
failed_queues = 0

async with TaskQueueManager() as manager:
# Pass job.logger so stream/consumer setup and any publish failures
# appear in the UI job log (not just the module logger). Per-image
# success logs stay at module level so a 10k-image job doesn't drown
# the job log.
async with TaskQueueManager(job_logger=job.logger) as manager:
for image_pk, task in tasks:
try:
logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
logger.debug(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
success = await manager.publish_task(
job_id=job.pk,
data=task,
)
except Exception as e:
logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
# job.logger.error triggers a sync Django ORM save inside
# JobLogHandler.emit, which raises SynchronousOnlyOperation
# when called directly from the event loop. Bridge it so
# the line actually lands in job.logs.stdout.
await sync_to_async(job.logger.error)(
f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
success = False

if success:
Expand Down
208 changes: 168 additions & 40 deletions ami/ml/orchestration/nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging

import nats
from asgiref.sync import sync_to_async
from django.conf import settings
from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
Expand Down Expand Up @@ -54,21 +55,75 @@ class TaskQueueManager:
nats_url: NATS server URL. Falls back to settings.NATS_URL, then "nats://nats:4222".
max_ack_pending: Max unacknowledged messages per consumer. Falls back to
settings.NATS_MAX_ACK_PENDING, then 1000.
job_logger: Optional per-job logger. When set, lifecycle events (stream /
consumer create or reuse, cleanup stats, publish failures) are mirrored
to this logger in addition to the module logger, so they appear in the
job's own log stream as seen from the UI. Per-message and per-poll
events stay on the module logger only to avoid drowning large jobs.

Use as an async context manager:
async with TaskQueueManager() as manager:
async with TaskQueueManager(job_logger=job.logger) as manager:
await manager.publish_task(123, {'data': 'value'})
tasks = await manager.reserve_tasks(123, count=64)
await manager.acknowledge_task(tasks[0].reply_subject)
"""

def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = None):
def __init__(
self,
nats_url: str | None = None,
max_ack_pending: int | None = None,
job_logger: logging.Logger | None = None,
):
self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222")
self.max_ack_pending = (
max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 1000)
)
self.job_logger = job_logger
self.nc: nats.NATS | None = None
self.js: JetStreamContext | None = None
# Dedupe lifecycle log lines per manager session so a job that publishes
# hundreds of tasks doesn't emit hundreds of "reusing stream" messages.
self._streams_logged: set[int] = set()
self._consumers_logged: set[int] = set()

async def _log(self, level: int, msg: str) -> None:
    """Emit ``msg`` at ``level`` on the module logger and mirror it to the job logger.

    The module logger is always called first, directly and synchronously.
    When a distinct per-job logger was supplied, the same record is then
    sent to it through ``sync_to_async``: the job log handler performs
    Django ORM work on every emit, and invoking it straight from the event
    loop raises ``SynchronousOnlyOperation`` and loses the line. Offloading
    the call to a worker thread lets the line reach ``job.logs.stdout``.

    Any exception raised while mirroring is caught and reported as a
    warning on the module logger, so logging a lifecycle event can never
    break the surrounding NATS operation.

    Args:
        level: Standard ``logging`` level number (e.g. ``logging.INFO``).
        msg: Fully formatted message text.
    """
    logger.log(level, msg)

    mirror = self.job_logger
    if mirror is None or mirror is logger:
        # No job logger configured, or it is the module logger itself and
        # mirroring would only duplicate the line emitted above.
        return

    try:
        await sync_to_async(mirror.log)(level, msg)
    except Exception as e:
        logger.warning(f"Failed to mirror log to job logger: {e}")

@staticmethod
def _format_consumer_stats(info) -> str:
    """Format a ConsumerInfo-like object into a compact, single-line stats string.

    Used for both reuse announcements and the forensic cleanup line. Every
    field is treated as optional: a missing value is rendered as ``?``
    instead of raising or printing ``None``. This covers a missing
    ``delivered`` / ``ack_floor`` pair as well as a pair that is present
    but whose ``consumer_seq`` is ``None``. A legitimate ``0`` is kept.

    Args:
        info: Object exposing ``delivered`` and ``ack_floor`` (each either
            ``None`` or an object with a ``consumer_seq`` attribute) plus
            ``num_pending``, ``num_ack_pending`` and ``num_redelivered``.

    Returns:
        A string of the form ``"delivered=… ack_floor=… num_pending=…
        num_ack_pending=… num_redelivered=…"``.
    """

    def shown(value):
        # Compare against None explicitly so that 0 is rendered as "0".
        return "?" if value is None else value

    def consumer_seq(pair):
        # The sequence pair itself may be absent, or present without a value.
        return shown(None if pair is None else pair.consumer_seq)

    return (
        f"delivered={consumer_seq(info.delivered)} "
        f"ack_floor={consumer_seq(info.ack_floor)} "
        f"num_pending={shown(info.num_pending)} "
        f"num_ack_pending={shown(info.num_ack_pending)} "
        f"num_redelivered={shown(info.num_redelivered)}"
    )

async def __aenter__(self):
"""Create connection on enter."""
Expand Down Expand Up @@ -127,27 +182,52 @@ async def _stream_exists(self, stream_name: str) -> bool:
return False

async def _ensure_stream(self, job_id: int):
"""Ensure stream exists for the given job."""
"""Ensure stream exists for the given job.

Logs a lifecycle line to both the module and job logger the first time it
sees a given job in this manager session (creation or reuse). Subsequent
calls in the same session are silent, so a job publishing N images doesn't
emit N log lines.
"""
if self.js is None:
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

if not await self._job_stream_exists(job_id):
stream_name = self._get_stream_name(job_id)
subject = self._get_subject(job_id)
logger.warning(f"Stream {stream_name} does not exist")
# Stream doesn't exist, create it
await asyncio.wait_for(
self.js.add_stream(
name=stream_name,
subjects=[subject],
max_age=86400, # 24 hours retention
),
timeout=NATS_JETSTREAM_TIMEOUT,
)
logger.info(f"Created stream {stream_name}")
stream_name = self._get_stream_name(job_id)
subject = self._get_subject(job_id)

try:
info = await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT)
if job_id not in self._streams_logged:
await self._log(
logging.INFO,
f"Reusing NATS stream {stream_name} "
f"(messages={info.state.messages}, last_seq={info.state.last_seq})",
)
self._streams_logged.add(job_id)
return
except nats.js.errors.NotFoundError:
pass

await asyncio.wait_for(
self.js.add_stream(
name=stream_name,
subjects=[subject],
max_age=86400, # 24 hours retention
),
timeout=NATS_JETSTREAM_TIMEOUT,
)
await self._log(logging.INFO, f"Created NATS stream {stream_name}")
self._streams_logged.add(job_id)

async def _ensure_consumer(self, job_id: int):
"""Ensure consumer exists for the given job."""
"""Ensure consumer exists for the given job.

On first sight in this manager session (creation or reuse), emits a line
to both the module and job logger. On creation the line includes the
config snapshot (max_deliver, ack_wait, max_ack_pending, deliver_policy,
ack_policy) so forensic readers can see exactly what delivery semantics
were in effect.
"""
if self.js is None:
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

Expand All @@ -160,27 +240,43 @@ async def _ensure_consumer(self, job_id: int):
self.js.consumer_info(stream_name, consumer_name),
timeout=NATS_JETSTREAM_TIMEOUT,
)
logger.debug(f"Consumer {consumer_name} already exists: {info}")
if job_id not in self._consumers_logged:
await self._log(
logging.INFO,
f"Reusing NATS consumer {consumer_name} ({self._format_consumer_stats(info)})",
)
self._consumers_logged.add(job_id)
return
except asyncio.TimeoutError:
raise # NATS unreachable — let caller handle it
except Exception:
# Consumer doesn't exist, create it
await asyncio.wait_for(
self.js.add_consumer(
stream=stream_name,
config=ConsumerConfig(
durable_name=consumer_name,
ack_policy=AckPolicy.EXPLICIT,
ack_wait=TASK_TTR, # Visibility timeout (TTR)
max_deliver=5, # Max retry attempts
deliver_policy=DeliverPolicy.ALL,
max_ack_pending=self.max_ack_pending,
filter_subject=subject,
),
# Consumer doesn't exist, fall through to create it.
pass
Comment thread
mihow marked this conversation as resolved.

await asyncio.wait_for(
self.js.add_consumer(
stream=stream_name,
config=ConsumerConfig(
durable_name=consumer_name,
ack_policy=AckPolicy.EXPLICIT,
ack_wait=TASK_TTR, # Visibility timeout (TTR)
max_deliver=5, # Max retry attempts
deliver_policy=DeliverPolicy.ALL,
max_ack_pending=self.max_ack_pending,
filter_subject=subject,
),
timeout=NATS_JETSTREAM_TIMEOUT,
)
logger.info(f"Created consumer {consumer_name}")
),
timeout=NATS_JETSTREAM_TIMEOUT,
)
await self._log(
logging.INFO,
f"Created NATS consumer {consumer_name} "
f"(max_deliver=5, ack_wait={TASK_TTR}s, "
Comment thread
mihow marked this conversation as resolved.
Outdated
f"max_ack_pending={self.max_ack_pending}, "
f"deliver_policy={DeliverPolicy.ALL.value}, "
f"ack_policy={AckPolicy.EXPLICIT.value})",
)
self._consumers_logged.add(job_id)

async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
"""
Expand Down Expand Up @@ -212,7 +308,10 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
return True

except Exception as e:
logger.error(f"Failed to publish task to stream for job '{job_id}': {e}")
# Per-message success logs stay at module level (noise in 10k-image
# jobs), but a failure on even a single publish deserves to surface
# in the job log — otherwise the failure path is invisible to users.
await self._log(logging.ERROR, f"Failed to publish task to stream for job '{job_id}': {e}")
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very helpful!

return False

async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> list[PipelineProcessingTask]:
Expand Down Expand Up @@ -292,6 +391,31 @@ async def acknowledge_task(self, reply_subject: str) -> bool:
logger.error(f"Failed to acknowledge task: {e}")
return False

async def _log_final_consumer_stats(self, job_id: int) -> None:
    """Emit one forensic line describing the job's consumer just before deletion.

    The line carries the formatted consumer stats (delivered, ack floor,
    pending, ack-pending and redelivered counts) so a post-mortem can see
    the consumer's state at the moment it was removed. This is strictly
    best-effort: with no open JetStream context the method returns
    immediately, and any failure while fetching consumer info (including a
    timeout, or the consumer/stream already being gone) is recorded at
    DEBUG level on the module logger and otherwise ignored, so cleanup is
    never blocked.

    Args:
        job_id: Primary key of the job whose consumer is about to be deleted.
    """
    js = self.js
    if js is None:
        return

    stream_name = self._get_stream_name(job_id)
    consumer_name = self._get_consumer_name(job_id)

    try:
        lookup = js.consumer_info(stream_name, consumer_name)
        info = await asyncio.wait_for(lookup, timeout=NATS_JETSTREAM_TIMEOUT)
    except Exception as e:
        logger.debug(f"Could not fetch consumer info for {consumer_name} before deletion: {e}")
        return

    stats = self._format_consumer_stats(info)
    await self._log(
        logging.INFO,
        f"Finalizing NATS consumer {consumer_name} before deletion ({stats})",
    )

async def delete_consumer(self, job_id: int) -> bool:
"""
Delete the consumer for a job.
Expand All @@ -313,10 +437,10 @@ async def delete_consumer(self, job_id: int) -> bool:
self.js.delete_consumer(stream_name, consumer_name),
timeout=NATS_JETSTREAM_TIMEOUT,
)
logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'")
await self._log(logging.INFO, f"Deleted NATS consumer {consumer_name} for job '{job_id}'")
return True
except Exception as e:
logger.error(f"Failed to delete consumer for job '{job_id}': {e}")
await self._log(logging.ERROR, f"Failed to delete NATS consumer for job '{job_id}': {e}")
return False

async def delete_stream(self, job_id: int) -> bool:
Expand All @@ -339,10 +463,10 @@ async def delete_stream(self, job_id: int) -> bool:
self.js.delete_stream(stream_name),
timeout=NATS_JETSTREAM_TIMEOUT,
)
logger.info(f"Deleted stream {stream_name} for job '{job_id}'")
await self._log(logging.INFO, f"Deleted NATS stream {stream_name} for job '{job_id}'")
return True
except Exception as e:
logger.error(f"Failed to delete stream for job '{job_id}': {e}")
await self._log(logging.ERROR, f"Failed to delete NATS stream for job '{job_id}': {e}")
return False

async def _setup_advisory_stream(self):
Expand Down Expand Up @@ -482,6 +606,10 @@ async def cleanup_job_resources(self, job_id: int) -> bool:
Returns:
bool: True if successful, False otherwise
"""
# Log a forensic snapshot of the consumer state BEFORE we destroy it.
# This is the highest-leverage line for post-mortem investigations.
await self._log_final_consumer_stats(job_id)

# Delete consumer first, then stream, then the durable DLQ advisory consumer
consumer_deleted = await self.delete_consumer(job_id)
stream_deleted = await self.delete_stream(job_id)
Expand Down
Loading
Loading