Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions ami/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@
from ami.users.models import User


class CachedCountField(models.IntegerField):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mohamedelabbas1996 I think we discussed a field type for cached counts last year. it finally happened!

"""Denormalized count of related rows.

Marker subclass so cached aggregate columns can be discovered via
``Model._meta.get_fields()`` (see ``ami.main.checks.cached_counts``
for the periodic drift-reconciliation check). Column type is unchanged
from ``IntegerField`` — the AlterField migrations that introduce this
subclass are no-op SQL. Mixing ``CachedCountField`` and plain
``IntegerField`` on the same model is fine, but a future contributor
adding a non-cached IntegerField next to a cached one will see an
AlterField in their migration; that's expected, not a bug.
"""

description = "Cached count of related rows"


def has_one_to_many_project_relation(model: type[models.Model]) -> bool:
"""
Returns True if the model has any ForeignKey or OneToOneField relationship to Project.
Expand Down
8 changes: 1 addition & 7 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,13 +712,7 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin):
Endpoint for viewing capture sets or samples of captures.
"""

queryset = (
SourceImageCollection.objects.all()
.with_source_images_count() # type: ignore
.with_source_images_with_detections_count()
.with_source_images_processed_count()
.prefetch_related("jobs")
)
queryset = SourceImageCollection.objects.all().prefetch_related("jobs")
serializer_class = SourceImageCollectionSerializer
permission_classes = [
ObjectPermission,
Expand Down
105 changes: 105 additions & 0 deletions ami/main/checks/cached_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Reconcile drift on ``CachedCountField`` columns.

Cached count columns (e.g. ``Deployment.captures_count``,
``SourceImageCollection.source_images_count``) are kept current via signals
and explicit ``update_calculated_fields`` calls. Bulk write paths that skip
signals — ``bulk_create``, ``bulk_update``, raw SQL, some ML post-processors
— silently drift the stored value out of sync with the underlying rows.

This check discovers every model that declares one or more
``CachedCountField`` columns, finds rows whose stored values disagree with
a fresh recompute via ``instance.update_calculated_fields(save=False)``,
and either reports or repairs them.

Run via ``manage.py check_data_integrity`` (when PR #1188 lands) or via
the ``reconcile_cached_counts`` Celery task.
"""

from __future__ import annotations

import logging
from collections.abc import Iterator

from django.apps import apps

from ami.base.models import BaseModel, CachedCountField
from ami.main.checks.schemas import IntegrityCheckResult

logger = logging.getLogger(__name__)


def _cached_count_field_names(model: type[BaseModel]) -> list[str]:
fields = model._meta.get_fields() # type: ignore[attr-defined]
return [f.name for f in fields if isinstance(f, CachedCountField) and f.name]


def discover_cached_count_fields() -> dict[type[BaseModel], list[str]]:
"""Return models that declare one or more ``CachedCountField`` columns."""
result: dict[type[BaseModel], list[str]] = {}
for model in apps.get_models():
if not issubclass(model, BaseModel):
continue
cached = _cached_count_field_names(model)
if cached:
result[model] = cached
return result


def _scope_to_project(qs, model: type[BaseModel], project_id: int | None):
if project_id is None:
return qs
project_accessor = model.get_project_accessor()
if project_accessor and project_accessor != "projects":
return qs.filter(**{f"{project_accessor}_id": project_id})
return qs


def find_stale_cached_counts(
model: type[BaseModel],
project_id: int | None = None,
) -> Iterator[tuple[BaseModel, dict[str, int | None], dict[str, int | None]]]:
"""Yield ``(instance, stored, computed)`` for rows whose cached counts drift.

Iterates the queryset row-by-row and calls ``update_calculated_fields(save=False)``
on a fresh copy so the stored row stays untouched. Heavy on large tables;
callers should scope by ``project_id`` whenever the check is interactive.
"""
cached_fields = _cached_count_field_names(model)
if not cached_fields:
return
qs = _scope_to_project(model.objects.all(), model, project_id)
for instance in qs.iterator():
stored = {f: getattr(instance, f) for f in cached_fields}
instance.update_calculated_fields(save=False)
computed = {f: getattr(instance, f) for f in cached_fields}
if stored != computed:
yield instance, stored, computed


def reconcile_cached_counts(
model: type[BaseModel] | None = None,
project_id: int | None = None,
dry_run: bool = True,
) -> IntegrityCheckResult:
"""Repair stale cached counts. Pass ``model=None`` to sweep all models."""
models_to_check = [model] if model else list(discover_cached_count_fields().keys())
result = IntegrityCheckResult()
for m in models_to_check:
for instance, stored, computed in find_stale_cached_counts(m, project_id=project_id):
result.checked += 1
logger.info(
"Stale cached counts on %s pk=%s: stored=%s computed=%s",
m.__name__,
instance.pk,
stored,
computed,
)
if dry_run:
continue
try:
instance.update_calculated_fields(save=True)
result.fixed += 1
except Exception:
logger.exception("Failed to reconcile %s pk=%s", m.__name__, instance.pk)
result.unfixable += 1
return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Denormalize three counts onto ``SourceImageCollection`` so the list endpoint
reads them in O(1) instead of running 3 correlated count subqueries per row.

Backfill uses a single GROUP BY over the M2M with FILTER clauses to compute
all three counts in one pass. ``with_det`` checks for a valid (non-null /
non-empty) detection bbox to match the runtime
``NULL_DETECTIONS_FILTER`` semantics in ``ami/main/models.py``.

``atomic = False`` so the long UPDATE can run outside a single transaction
on production-sized M2M tables.
"""

from django.db import migrations, models


def backfill_counts(apps, schema_editor):
schema_editor.execute(
"""
UPDATE main_sourceimagecollection sc
SET source_images_count = c.total,
source_images_processed_count = c.processed,
source_images_with_detections_count = c.with_det
FROM (
SELECT msci.sourceimagecollection_id AS coll_id,
COUNT(*) AS total,
COUNT(*) FILTER (
WHERE EXISTS (
SELECT 1 FROM main_detection d
WHERE d.source_image_id = si.id
)
) AS processed,
COUNT(*) FILTER (
WHERE EXISTS (
SELECT 1 FROM main_detection d
WHERE d.source_image_id = si.id
AND d.bbox IS NOT NULL
AND d.bbox::text <> '[]'
)
) AS with_det
FROM main_sourceimagecollection_images msci
INNER JOIN main_sourceimage si ON si.id = msci.sourceimage_id
GROUP BY msci.sourceimagecollection_id
) c
WHERE sc.id = c.coll_id;
"""
)
# Collections with no images: paginated SELECTs returned 0 via Coalesce; keep
# them populated rather than NULL so the column reads stay consistent.
schema_editor.execute(
"""
UPDATE main_sourceimagecollection
SET source_images_count = 0,
source_images_processed_count = 0,
source_images_with_detections_count = 0
WHERE source_images_count IS NULL;
"""
)


def reverse_noop(apps, schema_editor):
pass


class Migration(migrations.Migration):
atomic = False

dependencies = [
("main", "0084_revoke_delete_job_from_roles"),
]

operations = [
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_count",
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_with_detections_count",
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_processed_count",
field=models.IntegerField(default=0),
),
Comment on lines +24 to +38
migrations.RunPython(backfill_counts, reverse_noop),
]
73 changes: 73 additions & 0 deletions ami/main/migrations/0086_use_cached_count_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Generated by Django 4.2.10 on 2026-05-13 20:05

import ami.base.models
from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("main", "0085_denormalize_sourceimagecollection_counts"),
]

operations = [
migrations.AlterField(
model_name="deployment",
name="captures_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="detections_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="events_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="occurrences_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="taxa_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="event",
name="captures_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="event",
name="detections_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="event",
name="occurrences_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="sourceimage",
name="detections_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="sourceimagecollection",
name="source_images_count",
field=ami.base.models.CachedCountField(default=0),
),
migrations.AlterField(
model_name="sourceimagecollection",
name="source_images_processed_count",
field=ami.base.models.CachedCountField(default=0),
),
migrations.AlterField(
model_name="sourceimagecollection",
name="source_images_with_detections_count",
field=ami.base.models.CachedCountField(default=0),
),
]
Loading
Loading