diff --git a/ami/base/cached_counts.py b/ami/base/cached_counts.py new file mode 100644 index 000000000..0faa203bb --- /dev/null +++ b/ami/base/cached_counts.py @@ -0,0 +1,60 @@ +"""Helpers for scheduling cached-count recomputes after a transaction commits. + +Wraps the per-connection dedup + ``transaction.on_commit`` plumbing that +``BaseModel.update_cached_counts`` and ``BaseQuerySet.update_cached_counts`` +build on. The actual recompute body lives in each model's +``update_calculated_fields(save=True)`` implementation; this module only +handles scheduling. + +Per-(model_label, pk) dedup means N writes affecting the same target row +collapse to one task, regardless of how many signal handlers fire in the +transaction. The dedup set lives on the active DB connection (thread-local +in Django's default setup) and is drained by a single ``on_commit`` hook. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from django.db import connection, transaction + +logger = logging.getLogger(__name__) + +_PENDING_ATTR = "_pending_cached_count_recomputes" + + +def schedule_recompute(model_label: str, pk: Any) -> None: + """Queue a ``(model_label, pk)`` for recompute at the next commit. + + The pending set lives on the active DB connection; ``transaction.on_commit`` + fires the flush. ``_flush_pending_recomputes`` is idempotent — the first + call drains the set; subsequent ones no-op — so we register on_commit on + every call. That keeps us correct across transaction rollbacks (which + discard registered on_commits but leave attributes on ``connection`` + untouched, e.g. between a rolled-back ``TestCase`` and a fresh + ``TransactionTestCase``). + + Outside an atomic block, ``on_commit`` fires synchronously at + registration time — so the ``add`` below must precede the + ``transaction.on_commit`` call or the flush sees an empty set. + """ + pending: set[tuple[str, Any]] | None = getattr(connection, _PENDING_ATTR, None) + if pending is None: + pending = set() + setattr(connection, _PENDING_ATTR, pending) + pending.add((model_label, pk)) + transaction.on_commit(_flush_pending_recomputes) + + +def _flush_pending_recomputes() -> None: + """Drain the per-connection dedup set; dispatch one task per ``(model, pk)``.""" + from ami.main.tasks import recompute_cached_counts_task + + pending: set[tuple[str, Any]] = getattr(connection, _PENDING_ATTR, set()) + try: + delattr(connection, _PENDING_ATTR) + except AttributeError: + pass + for model_label, pk in pending: + recompute_cached_counts_task.delay(model_label, pk) diff --git a/ami/base/models.py b/ami/base/models.py index 2f245b745..62e09d6e4 100644 --- a/ami/base/models.py +++ b/ami/base/models.py @@ -7,6 +7,23 @@ from ami.users.models import User +class CachedCountField(models.IntegerField): + """Denormalized count of related rows. + + Marker subclass so cached aggregate columns can be discovered via + ``Model._meta.get_fields()`` + ``isinstance(f, CachedCountField)`` by + future cross-cutting tasks (admin display, periodic drift + reconciliation). Column type is unchanged + from ``IntegerField`` — the AlterField migrations that introduce this + subclass are no-op SQL. Mixing ``CachedCountField`` and plain + ``IntegerField`` on the same model is fine, but a future contributor + adding a non-cached IntegerField next to a cached one will see an + AlterField in their migration; that's expected, not a bug. + """ + + description = "Cached count of related rows" + + def has_one_to_many_project_relation(model: type[models.Model]) -> bool: """ Returns True if the model has any ForeignKey or OneToOneField relationship to Project. @@ -40,6 +57,29 @@ def has_many_to_many_project_relation(model: type[models.Model]) -> bool: class BaseQuerySet(QuerySet): + def update_cached_counts(self, run_async: bool = True) -> None: + """Recompute cached count columns for every row in the queryset. + + With ``run_async=True`` (default), each row is queued for recompute + via ``ami.base.cached_counts.schedule_recompute`` and dispatched as + a single ``recompute_cached_counts_task`` per ``(model, pk)`` after + the surrounding transaction commits. Repeated calls within the same + transaction dedupe through a per-connection set. + + With ``run_async=False``, each row is loaded and recomputed inline. + Suitable for Celery-worker contexts (reconcile task, ML pipeline + finalize) where the caller has already taken the latency hit. + """ + from ami.base.cached_counts import schedule_recompute + + model_label = self.model._meta.label + for pk in self.values_list("pk", flat=True): + if run_async: + schedule_recompute(model_label, pk) + else: + instance = self.model.objects.get(pk=pk) + instance.update_calculated_fields(save=True) + def visible_for_user(self, user: User | AnonymousUser) -> QuerySet: """ Filter queryset to include only objects whose related draft projects @@ -166,6 +206,24 @@ def update_calculated_fields(self, *args, **kwargs): """Update calculated fields specific to each model.""" pass + def update_cached_counts(self, run_async: bool = True) -> None: + """Recompute this row's cached count columns. + + With ``run_async=True`` (default), schedule a Celery task to run + after the surrounding transaction commits. Per-(model, pk) dedup + on the active DB connection collapses repeated calls within the + same transaction into a single task. + + With ``run_async=False``, recompute inline by calling + ``update_calculated_fields(save=True)`` directly. + """ + from ami.base.cached_counts import schedule_recompute + + if run_async: + schedule_recompute(self._meta.label, self.pk) + return + self.update_calculated_fields(save=True) + def _get_object_perms(self, user): """ Get the object-level permissions for the user on this instance. diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index 354df5459..dff62167a 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -1254,6 +1254,12 @@ class Meta: "created_at", "updated_at", ] + # Denormalized columns kept in sync by signals; never client-writable. + read_only_fields = [ + "source_images_count", + "source_images_with_detections_count", + "source_images_processed_count", + ] def get_permissions(self, instance, instance_data): request: Request = self.context["request"] diff --git a/ami/main/api/views.py b/ami/main/api/views.py index c4ca76da8..84981fb42 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -712,13 +712,7 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin): Endpoint for viewing capture sets or samples of captures. """ - queryset = ( - SourceImageCollection.objects.all() - .with_source_images_count() # type: ignore - .with_source_images_with_detections_count() - .with_source_images_processed_count() - .prefetch_related("jobs") - ) + queryset = SourceImageCollection.objects.all().prefetch_related("jobs") serializer_class = SourceImageCollectionSerializer permission_classes = [ ObjectPermission, diff --git a/ami/main/migrations/0085_denormalize_sourceimagecollection_counts.py b/ami/main/migrations/0085_denormalize_sourceimagecollection_counts.py new file mode 100644 index 000000000..503bdf1b3 --- /dev/null +++ b/ami/main/migrations/0085_denormalize_sourceimagecollection_counts.py @@ -0,0 +1,39 @@ +""" +Add three denormalized count columns to ``SourceImageCollection`` so the list +endpoint reads them in O(1) instead of running 3 correlated count subqueries +per row. + +Schema only. The backfill runs in the separate, non-atomic, re-runnable +migration ``0086_backfill_sourceimagecollection_counts`` so an interrupted +backfill on production-sized data cannot leave the schema half-applied +(columns added but migration unrecorded -> retry fails on duplicate column). + +``AddField`` with a constant ``default`` is a metadata-only operation on +PostgreSQL 11+, so this is safe to run atomically even on large tables. +""" + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0084_revoke_delete_job_from_roles"), + ] + + operations = [ + migrations.AddField( + model_name="sourceimagecollection", + name="source_images_count", + field=models.IntegerField(default=0), + ), + migrations.AddField( + model_name="sourceimagecollection", + name="source_images_with_detections_count", + field=models.IntegerField(default=0), + ), + migrations.AddField( + model_name="sourceimagecollection", + name="source_images_processed_count", + field=models.IntegerField(default=0), + ), + ] diff --git a/ami/main/migrations/0086_backfill_sourceimagecollection_counts.py b/ami/main/migrations/0086_backfill_sourceimagecollection_counts.py new file mode 100644 index 000000000..e40740186 --- /dev/null +++ b/ami/main/migrations/0086_backfill_sourceimagecollection_counts.py @@ -0,0 +1,65 @@ +""" +Backfill the denormalized ``SourceImageCollection`` count columns added in +0085. + +Split from the schema migration on purpose: this is the slow step on +production-sized M2M tables. ``atomic = False`` lets the UPDATE run outside a +single transaction, and the UPDATE writes absolute computed values (not +deltas) so it is idempotent — safe to re-run if interrupted. Collections with +no images keep the column ``default=0`` from 0085 (the GROUP BY only emits +rows for collections that have images). + +``with_det`` checks for a valid (non-null / non-empty) detection bbox to match +the runtime ``NULL_DETECTIONS_FILTER`` semantics in ``ami/main/models.py``. +""" + +from django.db import migrations + + +def backfill_counts(apps, schema_editor): + schema_editor.execute( + """ + UPDATE main_sourceimagecollection sc + SET source_images_count = c.total, + source_images_processed_count = c.processed, + source_images_with_detections_count = c.with_det + FROM ( + SELECT msci.sourceimagecollection_id AS coll_id, + COUNT(*) AS total, + COUNT(*) FILTER ( + WHERE EXISTS ( + SELECT 1 FROM main_detection d + WHERE d.source_image_id = si.id + ) + ) AS processed, + COUNT(*) FILTER ( + WHERE EXISTS ( + SELECT 1 FROM main_detection d + WHERE d.source_image_id = si.id + AND d.bbox IS NOT NULL + AND d.bbox::text <> '[]' + ) + ) AS with_det + FROM main_sourceimagecollection_images msci + INNER JOIN main_sourceimage si ON si.id = msci.sourceimage_id + GROUP BY msci.sourceimagecollection_id + ) c + WHERE sc.id = c.coll_id; + """ + ) + + +def reverse_noop(apps, schema_editor): + pass + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ("main", "0085_denormalize_sourceimagecollection_counts"), + ] + + operations = [ + migrations.RunPython(backfill_counts, reverse_noop), + ] diff --git a/ami/main/migrations/0087_use_cached_count_field.py b/ami/main/migrations/0087_use_cached_count_field.py new file mode 100644 index 000000000..5f568da60 --- /dev/null +++ b/ami/main/migrations/0087_use_cached_count_field.py @@ -0,0 +1,73 @@ +# Generated by Django 4.2.10 on 2026-05-13 20:05 + +import ami.base.models +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0086_backfill_sourceimagecollection_counts"), + ] + + operations = [ + migrations.AlterField( + model_name="deployment", + name="captures_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="deployment", + name="detections_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="deployment", + name="events_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="deployment", + name="occurrences_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="deployment", + name="taxa_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="event", + name="captures_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="event", + name="detections_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="event", + name="occurrences_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="sourceimage", + name="detections_count", + field=ami.base.models.CachedCountField(blank=True, null=True), + ), + migrations.AlterField( + model_name="sourceimagecollection", + name="source_images_count", + field=ami.base.models.CachedCountField(default=0), + ), + migrations.AlterField( + model_name="sourceimagecollection", + name="source_images_processed_count", + field=ami.base.models.CachedCountField(default=0), + ), + migrations.AlterField( + model_name="sourceimagecollection", + name="source_images_with_detections_count", + field=ami.base.models.CachedCountField(default=0), + ), + ] diff --git a/ami/main/models.py b/ami/main/models.py index b30b4e645..1d46d4cc5 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -32,7 +32,7 @@ import ami.tasks import ami.utils from ami.base.fields import DateStringField -from ami.base.models import BaseModel, BaseQuerySet +from ami.base.models import BaseModel, BaseQuerySet, CachedCountField from ami.main import charts from ami.main.models_future.filters import ( build_occurrence_default_filters_q, @@ -754,11 +754,11 @@ class Deployment(BaseModel): # data_source_last_check_notes = models.TextField(max_length=255, blank=True, null=True) # Pre-calculated values - events_count = models.IntegerField(blank=True, null=True) - occurrences_count = models.IntegerField(blank=True, null=True) - captures_count = models.IntegerField(blank=True, null=True) - detections_count = models.IntegerField(blank=True, null=True) - taxa_count = models.IntegerField(blank=True, null=True) + events_count = CachedCountField(blank=True, null=True) + occurrences_count = CachedCountField(blank=True, null=True) + captures_count = CachedCountField(blank=True, null=True) + detections_count = CachedCountField(blank=True, null=True) + taxa_count = CachedCountField(blank=True, null=True) first_capture_timestamp = models.DateTimeField(blank=True, null=True) last_capture_timestamp = models.DateTimeField(blank=True, null=True) @@ -1155,9 +1155,9 @@ class Event(BaseModel): occurrences: models.QuerySet["Occurrence"] # Pre-calculated values - captures_count = models.IntegerField(blank=True, null=True) - detections_count = models.IntegerField(blank=True, null=True) - occurrences_count = models.IntegerField(blank=True, null=True) + captures_count = CachedCountField(blank=True, null=True) + detections_count = CachedCountField(blank=True, null=True) + occurrences_count = CachedCountField(blank=True, null=True) calculated_fields_updated_at = models.DateTimeField(blank=True, null=True) class Meta: @@ -1942,7 +1942,7 @@ class SourceImage(BaseModel): test_image = models.BooleanField(default=False) # Precaclulated values - detections_count = models.IntegerField(null=True, blank=True) + detections_count = CachedCountField(null=True, blank=True) project = models.ForeignKey(Project, on_delete=models.SET_NULL, null=True, related_name="captures") deployment = models.ForeignKey(Deployment, on_delete=models.SET_NULL, null=True, related_name="captures") @@ -4093,32 +4093,6 @@ def html(self) -> str: class SourceImageCollectionQuerySet(BaseQuerySet): - def with_source_images_count(self): - return self.annotate( - source_images_count=models.Count( - "images", - distinct=True, - ) - ) - - def with_source_images_with_detections_count(self): - return self.annotate( - source_images_with_detections_count=models.Count( - "images", - filter=(~models.Q(images__detections__bbox__isnull=True) & ~models.Q(images__detections__bbox=[])), - distinct=True, - ) - ) - - def with_source_images_processed_count(self): - return self.annotate( - source_images_processed_count=models.Count( - "images", - filter=models.Q(images__detections__isnull=False), - distinct=True, - ) - ) - def with_source_images_processed_by_algorithm_count(self, algorithm_id: int): return self.annotate( source_images_processed_by_algorithm_count=models.Count( @@ -4205,6 +4179,12 @@ class SourceImageCollection(BaseModel): default=dict, ) + # Denormalized counts. Kept in sync via m2m_changed and pipeline-completion + # hooks. Reads are O(1). + source_images_count = CachedCountField(default=0) + source_images_with_detections_count = CachedCountField(default=0) + source_images_processed_count = CachedCountField(default=0) + objects = SourceImageCollectionManager() jobs: models.QuerySet["Job"] @@ -4219,19 +4199,6 @@ def infer_dataset_type(self): def dataset_type(self): return self.infer_dataset_type() - def source_images_count(self) -> int | None: - # This should always be pre-populated using queryset annotations - # return self.images.count() - return None - - def source_images_with_detections_count(self) -> int | None: - # This should always be pre-populated using queryset annotations - return None - - def source_images_processed_count(self) -> int | None: - # This should always be pre-populated using queryset annotations - return None - def occurrences_count(self) -> int | None: # This should always be pre-populated using queryset annotations return None @@ -4240,6 +4207,39 @@ def taxa_count(self) -> int | None: # This should always be pre-populated using queryset annotations return None + def get_source_image_counts(self) -> dict[str, int]: + """Return the 3 source-image counts as a dict. Single aggregate query; does not write to the DB.""" + valid_det = Detection.objects.filter(source_image=models.OuterRef("pk")).exclude(NULL_DETECTIONS_FILTER) + any_det = Detection.objects.filter(source_image=models.OuterRef("pk")) + counts = self.images.annotate( + _has_any_det=Exists(any_det), + _has_valid_det=Exists(valid_det), + ).aggregate( + source_images_count=models.Count("id"), + source_images_processed_count=models.Count("id", filter=models.Q(_has_any_det=True)), + source_images_with_detections_count=models.Count("id", filter=models.Q(_has_valid_det=True)), + ) + return counts + + def update_calculated_fields(self, save: bool = False) -> None: + """Recompute the 3 denormalized source-image count columns. + + Persists via ``.filter(pk=).update(**counts)`` rather than ``.save()`` + — cached-count refreshes shouldn't bump ``updated_at`` (semantically the + entity hasn't been modified) and shouldn't re-fire ``post_save``, which + on this model would re-enter ``m2m_changed`` if the handler later does + anything with ``self.images``. Deployment / Event / SourceImage take a + different path (``self.save(update_calculated_fields=False)``) because + their ``update_calculated_fields`` also writes non-cached fields that + downstream save-handlers expect to see updated. + """ + counts = self.get_source_image_counts() + self.source_images_count = counts["source_images_count"] + self.source_images_processed_count = counts["source_images_processed_count"] + self.source_images_with_detections_count = counts["source_images_with_detections_count"] + if save: + SourceImageCollection.objects.filter(pk=self.pk).update(**counts) + def get_queryset( self, *args, diff --git a/ami/main/signals.py b/ami/main/signals.py index e36e41937..6bf4f0ea4 100644 --- a/ami/main/signals.py +++ b/ami/main/signals.py @@ -2,11 +2,11 @@ from django.contrib.auth.models import Group from django.db import transaction -from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save +from django.db.models.signals import m2m_changed, post_delete, post_save, pre_delete, pre_save from django.dispatch import receiver from guardian.shortcuts import assign_perm -from ami.main.models import Project +from ami.main.models import Detection, Project, SourceImageCollection from ami.main.tasks import refresh_project_cached_counts from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project @@ -197,3 +197,40 @@ def exclude_taxa_updated(sender, instance: Project, action, **kwargs): if action in ["post_add", "post_remove", "post_clear"]: logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})") refresh_cached_counts_for_project(instance) + + +# ============================================================================ +# SourceImageCollection Denormalized Counts +# ============================================================================ +# +# Detection writes can fan out to many collections (one SourceImage may belong +# to multiple collections). The queryset's ``update_cached_counts()`` method +# dedupes per-(model, pk) across the transaction via a per-connection set, so +# a 10k-row pipeline save fires the recompute task at most once per affected +# collection instead of once per detection. Bulk write paths that bypass +# signals (``bulk_create``, ``bulk_update``, raw SQL) still drift the cached +# counts; ``pipeline.save_results()`` explicitly recomputes for the ML path. +# Generic periodic drift reconciliation across all CachedCountField models is +# tracked as a follow-up (see docs/claude/planning/cached-counts-reconcile-followup.md). + + +@receiver(m2m_changed, sender=SourceImageCollection.images.through) +def update_collection_counts_on_m2m(sender, instance, action, **kwargs): + """Recompute denormalized counts when images are added to or removed from a collection.""" + if action in ("post_add", "post_remove", "post_clear"): + instance.update_cached_counts() + + +@receiver(post_save, sender=Detection) +@receiver(post_delete, sender=Detection) +def update_collection_counts_on_detection_change(sender, instance, **kwargs): + """Schedule a collection-counts refresh for every collection containing the affected SourceImage. + + The queryset method's per-(model, pk) dedup means tight per-row Detection + write loops fan out to at most one task per affected collection. + ``bulk_create`` / ``bulk_update`` skip signals entirely — those rely on + the periodic reconciliation task to repair drift. + """ + if not instance.source_image_id: + return + SourceImageCollection.objects.filter(images__id=instance.source_image_id).update_cached_counts() diff --git a/ami/main/tasks.py b/ami/main/tasks.py index 16f927a3f..45637fb97 100644 --- a/ami/main/tasks.py +++ b/ami/main/tasks.py @@ -1,17 +1,43 @@ import logging +from django.apps import apps + from config import celery_app logger = logging.getLogger(__name__) +@celery_app.task(ignore_result=True) +def recompute_cached_counts_task(model_label: str, pk: int) -> None: + """Recompute one row's cached count columns. + + Dispatched by ``ami.base.cached_counts._flush_pending_recomputes`` after + a transaction commits. Generic across every model that declares + ``CachedCountField`` columns and implements ``update_calculated_fields``. + + Silent on missing rows: the row may have been deleted between when the + recompute was queued and when the task runs. + """ + model = apps.get_model(model_label) + try: + instance = model.objects.get(pk=pk) + except model.DoesNotExist: + logger.debug("recompute_cached_counts_task: %s pk=%s not found, skipping", model_label, pk) + return + instance.update_calculated_fields(save=True) + + @celery_app.task(ignore_result=True) def refresh_project_cached_counts(project_id: int) -> None: - """Refresh cached counts for all Events and Deployments in a project. + """Refresh cached counts on every Event, Deployment, and SourceImage in a project. - Dispatched from signals on ``Project.default_filters_*`` changes. The work - fans out to every Event and Deployment in the project, so it must not run - inline in the request/save path. + Dispatched from signals on ``Project.default_filters_*`` changes. The + cascade can touch tens of thousands of rows for a large project, so we + do the work inline in this single Celery task rather than queueing one + recompute task per row — that would flood the broker on a single filter + change. ``Project.update_related_calculated_fields()`` keeps the bulk + subquery UPDATE for ``SourceImage.detections_count`` while looping Events + and Deployments row-by-row. """ from ami.main.models import Project diff --git a/ami/main/tests.py b/ami/main/tests.py index c31d56faa..67572415d 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -6,7 +6,7 @@ from django.contrib.auth.models import AnonymousUser from django.core.files.uploadedfile import SimpleUploadedFile from django.db import connection, models -from django.test import TestCase, override_settings +from django.test import TestCase, TransactionTestCase, override_settings from guardian.shortcuts import assign_perm, get_perms, remove_perm from PIL import Image from rest_framework import status @@ -3226,12 +3226,11 @@ def test_list_response_shape_has_no_lazy_loads(self): for key in ("id", "url", "size_display", "deployment", "event", "detections_count", "path"): self.assertIn(key, row, f"missing field {key!r} in list response") self.assertIsNotNone(row["deployment"]["name"]) - # No lazy-load queries should fire after the main list SELECT. - # 1 list select + 1 detection prefetch (no, not in this call) + savepoints. + # After the main list SELECT, only prefetch + savepoint queries should fire; no per-row lazy loads. self.assertLessEqual( len(ctx.captured_queries), 6, - f"Unexpected extra queries — likely lazy-load from deferred field: {len(ctx.captured_queries)}", + f"Unexpected extra queries (likely lazy-load from a deferred field): {len(ctx.captured_queries)}", ) @@ -3275,6 +3274,201 @@ def test_list_query_count_does_not_scale_with_page_size(self): self.assertLessEqual(large, small + 5, f"Taxon list scaling: {small} -> {large} (likely N+1)") +@override_settings(CACHALOT_ENABLED=False) +class TestSourceImageCollectionListQueryCount(APITestCase): + """Audit SourceImageCollectionViewSet.list query counts. + + The three source-image counts are denormalized as columns on + SourceImageCollection (see migration 0085). The viewset no longer needs + per-row count subqueries, so list, with_counts, and ordering paths all run + against the column directly. + + Cachalot disabled so we measure cold query count, not warm cache. + """ + + def setUp(self): + self.project, self.deployment = setup_test_project() + create_taxa(self.project) + create_captures(deployment=self.deployment, num_nights=1, images_per_night=25) + create_occurrences(deployment=self.deployment, num=25, determination_score=0.9) + + images = list(SourceImage.objects.filter(deployment=self.deployment)) + # 30 collections so `limit=25` exercises a real page boundary; per-row + # subquery scaling regressions only show up once the page has >1 row. + for i in range(30): + c = SourceImageCollection.objects.create( + name=f"qcount-collection-{i}", + project=self.project, + method="manual", + kwargs={"image_ids": [img.pk for img in images]}, + ) + c.images.set(images) + + self.project.default_filters_score_threshold = 0.0 + self.project.save() + + self.user = User.objects.create_user( + email="qcount-collection@insectai.org", is_staff=False, is_superuser=False + ) + self.client.force_authenticate(user=self.user) + + def _list_query_count(self, url: str) -> int: + from django.core.cache import caches + from django.test.utils import CaptureQueriesContext + + caches["default"].clear() + with CaptureQueriesContext(connection) as ctx: + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK, res.content) + return len(ctx.captured_queries) + + def test_list_query_count_does_not_scale_with_page_size(self): + small = self._list_query_count(f"/api/v2/captures/collections/?project_id={self.project.pk}&limit=1") + large = self._list_query_count(f"/api/v2/captures/collections/?project_id={self.project.pk}&limit=25") + print(f"\n[AUDIT] Collection list: limit=1 -> {small}q, limit=25 -> {large}q") + self.assertLessEqual(large, small + 2, f"Collection list scaling: {small} -> {large} (likely N+1)") + + def test_list_query_count_with_counts(self): + url = f"/api/v2/captures/collections/?project_id={self.project.pk}&with_counts=true&limit=25" + # warmups equalise pool state / auth + self.client.get(url) + self.client.get(url) + count = self._list_query_count(url) + print(f"\n[AUDIT] Collection list with_counts=true limit=25 -> {count}q") + # 3 source-image counts now read from columns; with_counts only adds the + # occurrences/taxa subquery annotations. + self.assertLessEqual(count, 10, f"Collection list with_counts too many queries: {count}") + + def test_list_query_count_ordering_by_annotated_count(self): + url = f"/api/v2/captures/collections/?project_id={self.project.pk}" f"&limit=25&ordering=-source_images_count" + self.client.get(url) + self.client.get(url) + count = self._list_query_count(url) + print(f"\n[AUDIT] Collection list ordered by source_images_count limit=25 -> {count}q") + # Sort uses the cached column directly, so no extra subquery is added. + self.assertLessEqual(count, 10, f"Collection list ordered by count too many queries: {count}") + + +@override_settings(CELERY_TASK_ALWAYS_EAGER=True, CELERY_TASK_EAGER_PROPAGATES=True) +class TestSourceImageCollectionCountsDenormalize(TransactionTestCase): + """Verify denormalized count columns stay in sync via signals and bulk hooks. + + Uses ``TransactionTestCase`` + eager Celery so ``transaction.on_commit`` + callbacks (registered by the signal handlers) actually fire inside the + test body and the dispatched ``recompute_cached_counts_task`` runs inline. + """ + + def setUp(self): + self.project, self.deployment = setup_test_project() + create_taxa(self.project) + create_captures(deployment=self.deployment, num_nights=1, images_per_night=5) + self.images = list(SourceImage.objects.filter(deployment=self.deployment)) + self.collection = SourceImageCollection.objects.create( + name="denorm-test", + project=self.project, + method="manual", + kwargs={"image_ids": [img.pk for img in self.images]}, + ) + + def _refresh(self): + self.collection.refresh_from_db() + + def test_count_updates_on_image_add(self): + self.collection.images.set(self.images) + self._refresh() + self.assertEqual(self.collection.source_images_count, len(self.images)) + + def test_count_decrements_on_image_remove(self): + self.collection.images.set(self.images) + self.collection.images.remove(self.images[0]) + self._refresh() + self.assertEqual(self.collection.source_images_count, len(self.images) - 1) + + def test_with_detections_count_updates_on_detection_create(self): + self.collection.images.set(self.images) + Detection.objects.create( + source_image=self.images[0], + timestamp=self.images[0].timestamp, + bbox=[10, 10, 20, 20], + path="detections/d1.jpg", + ) + self._refresh() + self.assertEqual(self.collection.source_images_with_detections_count, 1) + self.assertEqual(self.collection.source_images_processed_count, 1) + + def test_with_detections_count_decrements_on_detection_delete(self): + self.collection.images.set(self.images) + det = Detection.objects.create( + source_image=self.images[0], + timestamp=self.images[0].timestamp, + bbox=[10, 10, 20, 20], + path="detections/d1.jpg", + ) + det.delete() + self._refresh() + self.assertEqual(self.collection.source_images_with_detections_count, 0) + self.assertEqual(self.collection.source_images_processed_count, 0) + + def test_null_bbox_detection_processed_but_no_with_detections(self): + """A null-bbox detection marks the image as processed but not 'with detections'.""" + self.collection.images.set(self.images) + Detection.objects.create( + source_image=self.images[0], + timestamp=self.images[0].timestamp, + bbox=None, + path="detections/null.jpg", + ) + self._refresh() + self.assertEqual(self.collection.source_images_processed_count, 1) + self.assertEqual(self.collection.source_images_with_detections_count, 0) + + def test_update_calculated_fields_recomputes_from_scratch(self): + self.collection.images.set(self.images) + Detection.objects.create( + source_image=self.images[0], + timestamp=self.images[0].timestamp, + bbox=[10, 10, 20, 20], + path="detections/d1.jpg", + ) + SourceImageCollection.objects.filter(pk=self.collection.pk).update( + source_images_count=999, + source_images_processed_count=999, + source_images_with_detections_count=999, + ) + self.collection.refresh_from_db() + self.collection.update_calculated_fields(save=True) + self._refresh() + self.assertEqual(self.collection.source_images_count, len(self.images)) + self.assertEqual(self.collection.source_images_with_detections_count, 1) + + def test_get_source_image_counts_returns_dict_without_writes(self): + self.collection.images.set(self.images) + Detection.objects.create( + source_image=self.images[0], + timestamp=self.images[0].timestamp, + bbox=[10, 10, 20, 20], + path="detections/d1.jpg", + ) + SourceImageCollection.objects.filter(pk=self.collection.pk).update( + source_images_count=0, + source_images_processed_count=0, + source_images_with_detections_count=0, + ) + self.collection.refresh_from_db() + counts = self.collection.get_source_image_counts() + self.assertEqual( + counts, + { + "source_images_count": len(self.images), + "source_images_processed_count": 1, + "source_images_with_detections_count": 1, + }, + ) + # Confirm the DB row was not updated. + self.collection.refresh_from_db() + self.assertEqual(self.collection.source_images_count, 0) + + class TestProjectDefaultTaxaFilter(APITestCase): """ Tests for project default taxa filtering (include/exclude lists). diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index c259e4aea..0c6c1dbb3 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -999,6 +999,11 @@ def save_results( for deployment in Deployment.objects.filter(pk__in=deployment_ids): deployment.update_calculated_fields(save=True) + # bulk_create above skips Detection signals; refresh affected collections explicitly. + source_image_ids = [img.pk for img in source_images] + for collection in SourceImageCollection.objects.filter(images__id__in=source_image_ids).distinct(): + collection.update_calculated_fields(save=True) + total_time = time.time() - start_time job_logger.info(f"Saved results from pipeline {pipeline} in {total_time:.2f} seconds") diff --git a/docs/claude/planning/cached-counts-reconcile-followup.md b/docs/claude/planning/cached-counts-reconcile-followup.md new file mode 100644 index 000000000..ec4dc3c11 --- /dev/null +++ b/docs/claude/planning/cached-counts-reconcile-followup.md @@ -0,0 +1,148 @@ +# Follow-up: CachedCountField drift reconciliation + periodic task + dashboard + +**Status:** planned (split out of PR #1301 on 2026-05-15) +**Depends on:** PR #1301 (denormalized `SourceImageCollection` counts + `CachedCountField` marker) merged. + +## Why this is a separate PR + +#1301 ships the perf win (denormalized columns, signals, ML-path recompute, +`CachedCountField` marker, per-connection dedup scheduler). The generic +drift-reconciliation layer was removed from #1301 because it was not +production-wired and shipping it half-done is worse than not shipping it: + +- The reconcile Celery task had **no beat schedule** — `CELERY_BEAT_SCHEDULER` + is `django_celery_beat.schedulers:DatabaseScheduler` (`config/settings/base.py:413`), + so periodic tasks live in the DB. #1301 added no `PeriodicTask` registration, + so the "safety net" never actually ran. +- `reconcile_cached_counts` counted `checked` only for rows that **already + drifted**, contradicting the `IntegrityCheckResult` contract + (`ami/main/checks/schemas.py`: `checked` = rows inspected). A clean sweep + reported `checked=0`, indistinguishable from "didn't run". +- Task default was `project_id=None, dry_run=False` → a full-table, + repair-mode, per-row-subquery sweep across every `CachedCountField` model + (incl. `SourceImage`, millions of rows). Dangerous default if naively + scheduled. + +Removed from #1301 (recover from git history at branch +`perf/sourceimagecollection-cached-counts` pre-2026-05-15): +- `ami/main/checks/cached_counts.py` (`discover_cached_count_fields`, + `find_stale_cached_counts`, `reconcile_cached_counts`) +- `reconcile_cached_counts_task` in `ami/main/tasks.py` +- `TestCachedCountsIntegrityCheck` in `ami/main/tests.py` + +The `CachedCountField` marker, `discover`-via-`_meta.get_fields()` approach, +and the design doc (`docs/superpowers/specs/2026-05-14-cached-counts-update-method-design.md`) +stay in #1301 — the follow-up builds on the marker. + +## Scope of the follow-up + +1. **Reconcile module** — restore `find_stale_cached_counts` / + `reconcile_cached_counts`. Fix `checked` to increment per row inspected + (in the iteration, not the drift branch). Keep `dry_run` and `project_id` + scoping. +2. **Safe task defaults** — `reconcile_cached_counts_task` defaults to + `dry_run=True`; repair mode must be explicit. Require either `project_id` + or an explicit `model` for the repair path; refuse a full-table unscoped + repair without an explicit `force=True`. +3. **Periodic task registration** — data migration creating the + `django_celery_beat` `PeriodicTask` (dry-run mode, reasonable cadence, + per-project fan-out rather than one unscoped sweep). Document how to + enable repair mode per environment. +4. **Surface results** — dashboard / log destination for reconcile output + (checked / fixed / drift detail). Decide: admin page, structured logs to + the existing logging sink, or a lightweight status model. Drift events + should be visible without grepping worker logs. +5. **Tests** — restore the integrity-check tests; add one asserting + `checked` reflects rows inspected on a no-drift sweep (the bug the old + `test_reconcile_no_drift_returns_zero_checked` baked in). + +## Open questions + +- Cadence + scoping: per-project nightly vs. one global weekly sweep. Per-row + subquery recompute on `SourceImage` at prod scale is expensive — likely + needs the bulk-subquery UPDATE path (`Project.update_related_calculated_fields`) + rather than row-by-row `update_calculated_fields` for the big models. +- Whether reconcile should auto-repair or only alert + require a manual + trigger for repair (safer; drift usually signals a missing signal/hook + that should be fixed at the source, not papered over). + +## Reconciler compute strategy: read-only vs upsert vs invalidate + +We already have the read/write split on the eager path: +`SourceImageCollection.get_source_image_counts()` is pure-compute (one +aggregate, no writes); `update_calculated_fields(save=True)` is the +side-effecting upsert. That split is correct for the **signal** path but the +pure-compute read is the wrong primitive for a **sweep** — the removed +reconciler iterated rows calling it via `.iterator()`, which is N subquery +aggregates. Three named approaches, ranked: + +1. **Set-based diff (eager fields, chosen).** One GROUP BY producing the true + counts for every row in a single pass — this query already exists as the + `0086_backfill_sourceimagecollection_counts` SQL. Reconcile = run it, + compare to stored columns (`WHERE sc.x IS DISTINCT FROM c.x` to detect, + `UPDATE ... WHERE` to repair only divergent rows). O(passes), not O(rows). + Precedent: `Project.update_related_calculated_fields()` keeps a + bulk-subquery UPDATE for `SourceImage.detections_count` rather than + looping. **Implication:** the per-model "true value" query is the single + source of truth; backfill, reconcile, and (where cheap) the signal path + should all derive from it, closing the migration/runtime predicate-drift + gap flagged in the takeaway review. + +2. **Lazy invalidation (expensive fields, considered — see below).** Mark + stale on a write criterion; recompute on read or on a prioritized sweep. + Named pattern: *write-invalidate* / *cache-aside* with *TTL* or + *generation/version stamping*. Not warranted for the cheap count columns + (eager signal recompute is one aggregate per affected row), but it is the + right path for future expensive caches. + +3. **Generic per-row loop (diagnostic only).** Lowest-common-denominator, + works for any `CachedCountField` model, slow at scale. Survives only as a + scoped / `dry_run` diagnostic, never an unscoped repair sweep. + +## Design space for future expensive cached fields (considered, not in scope) + +Counts are cheap to recompute eagerly. Some cached values will not be — e.g. +a stored `best_machine_prediction_score` / `best_machine_prediction_taxon_id` +on `Occurrence` (today a queryset annotation, `with_best_machine_prediction`, +not a column). When a cached value is expensive, eager write-through stops +being viable and the field needs a **freshness signal**. Named options: + +- **NULL-sentinel staleness.** Nullable column where `NULL` = "not computed / + stale, recompute before trusting". Zero extra schema; this is *already* the + de-facto contract for `Deployment.*_count` / `Event.*_count` + (`blank=True, null=True`). Limitation: conflates "computed and genuinely + empty" with "stale". Fine for counts (`NULL` ≠ `0`) and for FK/score caches + (`NULL` = unknown). The "stale bit = set it to None" intuition is exactly + this pattern. + +- **Freshness-timestamp companion.** A sibling `_computed_at` + datetime; staleness = `computed_at IS NULL OR computed_at <= source_changed_at` + (TTL or watermark). Precedent already in the codebase: + `Event.calculated_fields_updated_at` (`models.py:1161`) gates recompute via + `Q(calculated_fields_updated_at__isnull=True) | Q(...__lte=last_updated)` + (`models.py:1346`). Preferred for expensive fields: enables TTL, + oldest-first reconcile prioritization, and observability ("how stale is + this?"). This is the "value + stale bit, two-part field" idea — the second + part is a freshness timestamp, not a boolean. + +- **Aggregates / rollup table (or materialized view).** For very expensive + cross-table rollups, store the value out of the hot row entirely: a FK from + the entity to a summary row, or a DB **materialized view** refreshed on a + schedule. Data-warehouse name: *aggregate table* / *summary table*; the + DB-maintained variant is a *materialized view* (`REFRESH MATERIALIZED + VIEW`). Trades write-amplification for refresh latency; reconcile becomes + "refresh the view" rather than per-row diff. + +**Naming pathway (so future work has somewhere to land, no build now):** +keep `CachedCountField(IntegerField)` as the *eager write-through scalar* +marker. Reserve a sibling concept for *lazy / invalidatable* caches that +carry a freshness signal — a class such as `LazyCachedField` / +`InvalidatableCachedField`, distinguished by an associated `computed_at` +companion (or documented NULL-sentinel contract). The discovery mechanism +stays one `_meta.get_fields()` + `isinstance` sweep; the **marker class +hierarchy tells the reconciler which strategy applies**: eager → set-based +diff (approach 1); lazy → check freshness/NULL, then recompute or merely +invalidate (approach 2); rollup-backed → refresh the aggregate. One +enumeration, per-class strategy. This keeps #1301's marker the right shape +and leaves a clear, named path to the expensive cases without widening the +current PR. diff --git a/docs/superpowers/specs/2026-05-14-cached-counts-update-method-design.md b/docs/superpowers/specs/2026-05-14-cached-counts-update-method-design.md new file mode 100644 index 000000000..5721776e6 --- /dev/null +++ b/docs/superpowers/specs/2026-05-14-cached-counts-update-method-design.md @@ -0,0 +1,183 @@ +# Cached counts: `update_cached_counts` method design + +**Date:** 2026-05-14 +**Context:** Follow-up to PR #1301 takeaway-review feedback. Replace per-source-table dedup state + per-model Celery refresh tasks with a generic instance/queryset method that wraps `update_calculated_fields(save=True)`. + +## Goals + +1. Single source of truth for "recompute and persist this row's cached counts" across the codebase. +2. Caller-controlled sync vs async (`run_async=True` default). +3. Per-(model, pk) dedup so high-volume signal fan-out collapses to one task per affected row, regardless of how many source-row writes triggered it. +4. No new concepts at the field declaration site: `CachedCountField` marker, model `update_calculated_fields` body, and the periodic reconcile task stay as-is. + +## Non-goals + +- Declarative `invalidate_on=[Detection, ...]` on field. Deferred to follow-up; the registry would sit on top of this method. +- Plugging the bulk_create / bulk_update / raw-SQL blind spot. That stays the responsibility of `reconcile_cached_counts_task` and inline calls in ML worker code (`pipeline.save_results`). +- Splitting `update_calculated_fields` into "just counts" vs "derived state (S3 sums, first/last timestamps)". Wrapper stays thin today; semantic split is a separate concern when refreshing one drift forces a full S3 scan and we notice. + +## Architecture + +### New module: `ami/base/cached_counts.py` + +Per-connection dedup set keyed by `(model_label, pk)`. One `transaction.on_commit` hook per connection drains the set and dispatches the generic Celery task once per unique `(model_label, pk)`. + +```python +_PENDING_ATTR = "_pending_cached_count_recomputes" + +def _schedule_recompute(model: type[models.Model], pk: Any) -> None: + pending = getattr(connection, _PENDING_ATTR, None) + is_new = pending is None + if is_new: + pending = set() + setattr(connection, _PENDING_ATTR, pending) + pending.add((model._meta.label, pk)) + if is_new: + # Outside an atomic block, on_commit fires synchronously at + # registration time — the add above must precede it. + transaction.on_commit(_flush_pending_recomputes) + + +def _flush_pending_recomputes() -> None: + pending = getattr(connection, _PENDING_ATTR, set()) + try: + delattr(connection, _PENDING_ATTR) + except AttributeError: + pass + for label, pk in pending: + recompute_cached_counts_task.delay(label, pk) + + +@shared_task(ignore_result=True) +def recompute_cached_counts_task(model_label: str, pk: Any) -> None: + model = apps.get_model(model_label) + try: + instance = model.objects.get(pk=pk) + except model.DoesNotExist: + return + instance.update_calculated_fields(save=True) +``` + +### `BaseModel.update_cached_counts(run_async=True)` + +```python +class BaseModel(models.Model): + def update_cached_counts(self, run_async: bool = True) -> None: + if run_async: + _schedule_recompute(type(self), self.pk) + return + self.update_calculated_fields(save=True) +``` + +### `BaseQuerySet.update_cached_counts(run_async=True)` + +```python +class BaseQuerySet(QuerySet): + def update_cached_counts(self, run_async: bool = True) -> None: + for pk in self.values_list("pk", flat=True): + if run_async: + _schedule_recompute(self.model, pk) + else: + self.model.objects.get(pk=pk).update_calculated_fields(save=True) +``` + +## Call site changes + +### `ami/main/signals.py` + +Detection post_save/post_delete handler: + +```python +@receiver(post_save, sender=Detection) +@receiver(post_delete, sender=Detection) +def update_collection_counts_on_detection_change(sender, instance, **kwargs): + if not instance.source_image_id: + return + SourceImageCollection.objects.filter(images__id=instance.source_image_id).update_cached_counts() +``` + +m2m_changed on `SourceImageCollection.images.through`: + +```python +@receiver(m2m_changed, sender=SourceImageCollection.images.through) +def update_collection_counts_on_m2m(sender, instance, action, **kwargs): + if action in ("post_add", "post_remove", "post_clear"): + instance.update_cached_counts() +``` + +Project default-filter cascade (stays hand-rolled; cascade is to children, not parents): + +```python +def refresh_cached_counts_for_project(project: Project): + Event.objects.filter(project=project).update_cached_counts() + Deployment.objects.filter(project=project).update_cached_counts() + SourceImage.objects.filter(project=project).update_cached_counts() +``` + +### `ami/main/tasks.py` + +Drop `refresh_collection_cached_counts` entirely. `refresh_project_cached_counts` can also drop; the per-project cascade now schedules per-row tasks directly from the signal via the queryset method's `run_async=True` default. Reconcile task stays. + +### `ami/main/checks/cached_counts.py` reconcile loop + +```python +# before +instance.update_calculated_fields(save=True) +# after +instance.update_cached_counts(run_async=False) +``` + +Synchronous because reconcile already runs in a Celery task and we want the repair to complete before the result is reported. + +### `ami/ml/models/pipeline.py` (worker context) + +Stays as-is. Already runs in Celery and already dedupes via `.distinct()` on the collection queryset. Could optionally swap `collection.update_calculated_fields(save=True)` → `collection.update_cached_counts(run_async=False)` for stylistic unification — non-blocking on this PR. + +## What goes away + +- `_PENDING_SOURCE_IMAGE_IDS_ATTR` constant +- `_flush_pending_collection_refreshes` helper +- `_schedule_collection_refresh_for_source_image` helper +- `refresh_collection_cached_counts` task +- `refresh_project_cached_counts` task (its body becomes 3 queryset calls in the signal handler) + +## What stays + +- `CachedCountField` marker class +- Per-model `update_calculated_fields(save=True)` bodies (the actual recompute logic) +- Periodic `reconcile_cached_counts_task` and the integrity check module +- Inline calls in `pipeline.save_results` (worker-context, already deduped) + +## Cost of adding the next cached count + +Before: new field + recompute in `update_calculated_fields` + per-connection dedup attr + flush helper + Celery task + signal handler wiring (~6 things, ~50 LOC). + +After: new field + recompute in `update_calculated_fields` + signal handler calling `.update_cached_counts()` (~3 things, ~10 LOC). + +## Risks + +1. **bulk_create / bulk_update skip signals.** Unchanged from current state. Reconcile task is the safety net. Cachalot accepts the same boundary at the SQL-compiler patch layer (raw cursor coverage is opt-in via `CACHALOT_INVALIDATE_RAW`). +2. **Project default-filter cascade fans out to thousands of children.** Today it's one task; under this design it becomes N small tasks. Net cost is slightly higher (more queue overhead) but each task is bounded and parallelizable. Separate issue from this PR. +3. **`update_calculated_fields` on Deployment does S3-sum + first/last timestamp work alongside the counts.** Refreshing drift on one count therefore triggers an S3 query. Acceptable today; flagged for future split. +4. **`async` is a Python reserved word.** Use `run_async` to match existing precedent (`process_single_source_image(run_async=True)`). + +## Migration path + +This PR (or a follow-up commit on PR #1301): + +1. Create `ami/base/cached_counts.py` with `_schedule_recompute`, `_flush_pending_recomputes`, `recompute_cached_counts_task`. +2. Add `update_cached_counts` to `BaseModel` and `BaseQuerySet` in `ami/base/models.py`. +3. Refactor `ami/main/signals.py` — drop dedup helpers, switch handlers to queryset method. +4. Refactor `ami/main/tasks.py` — drop `refresh_collection_cached_counts` and `refresh_project_cached_counts`. +5. Update `ami/main/checks/cached_counts.py` reconcile loop. +6. Run existing tests in `ami/main/tests.py` (`test_source_image_cached_counts_refresh_on_threshold_change` etc.) to confirm parity. + +## Tests + +Existing tests cover: +- Threshold-change signal triggers refresh +- Detection post_save triggers per-collection refresh +- m2m_changed triggers refresh +- Bulk write drift is caught by reconcile + +These should pass unchanged. One new test: per-connection dedup collapses N detection writes to ≤ N tasks, where N is the number of distinct affected target rows. (PR #1301 has the dedup test for the old code path; rewrite to assert against the new generic task.)