Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions ami/base/cached_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Helpers for scheduling cached-count recomputes after a transaction commits.

Wraps the per-connection dedup + ``transaction.on_commit`` plumbing that
``BaseModel.update_cached_counts`` and ``BaseQuerySet.update_cached_counts``
build on. The actual recompute body lives in each model's
``update_calculated_fields(save=True)`` implementation; this module only
handles scheduling.

Per-(model_label, pk) dedup means N writes affecting the same target row
collapse to one task, regardless of how many signal handlers fire in the
transaction. The dedup set lives on the active DB connection (thread-local
in Django's default setup) and is drained by a single ``on_commit`` hook.
"""

from __future__ import annotations

import logging
from typing import Any

from django.db import connection, transaction

logger = logging.getLogger(__name__)

_PENDING_ATTR = "_pending_cached_count_recomputes"


def schedule_recompute(model_label: str, pk: Any) -> None:
    """Queue a ``(model_label, pk)`` pair for recompute at the next commit.

    The pending set is stored as an attribute on the active DB connection and
    is drained by ``_flush_pending_recomputes``, which is registered through
    ``transaction.on_commit``. The flush is idempotent (the first call drains
    the set, later calls find nothing), so the hook is registered on every
    call. That keeps scheduling correct across rollbacks, which discard
    registered on_commit hooks but leave attributes on ``connection``
    untouched (e.g. between a rolled-back ``TestCase`` and a fresh
    ``TransactionTestCase``).

    Args:
        model_label: Model label used as the first half of the dedup key and
            passed to the recompute task (callers in this codebase pass
            ``Model._meta.label``).
        pk: Primary key of the row to recompute. ``None`` (an instance that
            has not been saved yet) is skipped with a warning, because there
            is no row a task could recompute.
    """
    if pk is None:
        # An unsaved instance has no row to recompute; queueing it would only
        # dispatch a task that cannot find its target.
        logger.warning("Skipping cached-count recompute for %s: no primary key", model_label)
        return

    pending = getattr(connection, _PENDING_ATTR, None)
    if pending is None:
        pending = set()
        setattr(connection, _PENDING_ATTR, pending)

    # Order matters: outside an atomic block ``on_commit`` runs the callback
    # synchronously at registration time, so the pair must already be in the
    # set or the flush would see it empty.
    pending.add((model_label, pk))
    transaction.on_commit(_flush_pending_recomputes)


def _flush_pending_recomputes() -> None:
    """Drain the per-connection dedup set; dispatch one task per ``(model, pk)``.

    The set is detached from the connection before any task is dispatched, so
    repeated invocations (one ``on_commit`` hook is registered per
    ``schedule_recompute`` call) find nothing left to do.

    Dispatch is best-effort per pair: a failure to enqueue one task is logged
    and the remaining pairs are still dispatched. Letting the exception escape
    would drop every pair not yet dispatched (the set is already detached) and
    would prevent later on_commit callbacks of the same transaction from
    running.
    """
    # Imported lazily, presumably to avoid an import cycle with the tasks
    # module; verify before moving to module level.
    from ami.main.tasks import recompute_cached_counts_task

    pending = getattr(connection, _PENDING_ATTR, set())
    try:
        delattr(connection, _PENDING_ATTR)
    except AttributeError:
        # Nothing was ever queued on this connection, or an earlier hook
        # already drained it.
        pass

    for model_label, pk in pending:
        try:
            recompute_cached_counts_task.delay(model_label, pk)
        except Exception:
            # Boundary handler: the write that triggered this recompute has
            # already committed, so log the failure and keep dispatching.
            logger.exception(
                "Failed to dispatch cached-count recompute for %s pk=%s",
                model_label,
                pk,
            )
58 changes: 58 additions & 0 deletions ami/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,23 @@
from ami.users.models import User


class CachedCountField(models.IntegerField):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mohamedelabbas1996 I think we discussed a field type for cached counts last year. it finally happened!

"""Denormalized count of related rows.

Marker subclass so cached aggregate columns can be discovered via
``Model._meta.get_fields()`` + ``isinstance(f, CachedCountField)`` by
future cross-cutting tasks (admin display, periodic drift
reconciliation). Column type is unchanged
from ``IntegerField`` — the AlterField migrations that introduce this
subclass are no-op SQL. Mixing ``CachedCountField`` and plain
``IntegerField`` on the same model is fine, but a future contributor
adding a non-cached IntegerField next to a cached one will see an
AlterField in their migration; that's expected, not a bug.
"""

description = "Cached count of related rows"


def has_one_to_many_project_relation(model: type[models.Model]) -> bool:
"""
Returns True if the model has any ForeignKey or OneToOneField relationship to Project.
Expand Down Expand Up @@ -40,6 +57,29 @@ def has_many_to_many_project_relation(model: type[models.Model]) -> bool:


class BaseQuerySet(QuerySet):
def update_cached_counts(self, run_async: bool = True) -> None:
    """Refresh the cached count columns of every row matched by this queryset.

    ``run_async=True`` (the default) hands each primary key to
    ``ami.base.cached_counts.schedule_recompute``, which queues the
    ``(model, pk)`` pair for a ``recompute_cached_counts_task`` dispatched
    once the surrounding transaction commits; repeated calls inside the same
    transaction are deduplicated through a per-connection set.

    ``run_async=False`` fetches each row and recomputes it inline via
    ``update_calculated_fields(save=True)``. Intended for Celery-worker
    contexts (reconcile task, ML pipeline finalize) where the caller has
    already accepted the latency.
    """
    from ami.base.cached_counts import schedule_recompute

    label = self.model._meta.label
    row_pks = self.values_list("pk", flat=True)

    if run_async:
        # Deferred path: only queue the keys, no rows are loaded here.
        for row_pk in row_pks:
            schedule_recompute(label, row_pk)
        return

    # Inline path: load each row individually and recompute it right away.
    for row_pk in row_pks:
        row = self.model.objects.get(pk=row_pk)
        row.update_calculated_fields(save=True)

def visible_for_user(self, user: User | AnonymousUser) -> QuerySet:
"""
Filter queryset to include only objects whose related draft projects
Expand Down Expand Up @@ -166,6 +206,24 @@ def update_calculated_fields(self, *args, **kwargs):
"""Update calculated fields specific to each model."""
pass

def update_cached_counts(self, run_async: bool = True) -> None:
    """Refresh the cached count columns of this single row.

    ``run_async=True`` (the default) queues a Celery task that runs after the
    surrounding transaction commits; repeated calls for the same
    ``(model, pk)`` within one transaction are collapsed into a single task
    by the per-connection dedup set.

    ``run_async=False`` performs the recompute immediately by calling
    ``update_calculated_fields(save=True)``.
    """
    from ami.base.cached_counts import schedule_recompute

    if not run_async:
        # Inline path: recompute and persist now.
        self.update_calculated_fields(save=True)
        return

    # Deferred path: queue this row for recompute at commit time.
    schedule_recompute(self._meta.label, self.pk)

def _get_object_perms(self, user):
"""
Get the object-level permissions for the user on this instance.
Expand Down
6 changes: 6 additions & 0 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,12 @@ class Meta:
"created_at",
"updated_at",
]
# Denormalized columns kept in sync by signals; never client-writable.
read_only_fields = [
"source_images_count",
"source_images_with_detections_count",
"source_images_processed_count",
]

def get_permissions(self, instance, instance_data):
request: Request = self.context["request"]
Expand Down
8 changes: 1 addition & 7 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,13 +712,7 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin):
Endpoint for viewing capture sets or samples of captures.
"""

queryset = (
SourceImageCollection.objects.all()
.with_source_images_count() # type: ignore
.with_source_images_with_detections_count()
.with_source_images_processed_count()
.prefetch_related("jobs")
)
queryset = SourceImageCollection.objects.all().prefetch_related("jobs")
serializer_class = SourceImageCollectionSerializer
permission_classes = [
ObjectPermission,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Add three denormalized count columns to ``SourceImageCollection`` so the list
endpoint reads them in O(1) instead of running 3 correlated count subqueries
per row.

Schema only. The backfill runs in the separate, non-atomic, re-runnable
migration ``0086_backfill_sourceimagecollection_counts`` so an interrupted
backfill on production-sized data cannot leave the schema half-applied
(columns added but migration unrecorded -> retry fails on duplicate column).

``AddField`` with a constant ``default`` is a metadata-only operation on
PostgreSQL 11+, so this is safe to run atomically even on large tables.
"""

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("main", "0084_revoke_delete_job_from_roles"),
]

operations = [
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_count",
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_with_detections_count",
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_processed_count",
field=models.IntegerField(default=0),
),
Comment on lines +24 to +38
]
65 changes: 65 additions & 0 deletions ami/main/migrations/0086_backfill_sourceimagecollection_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Backfill the denormalized ``SourceImageCollection`` count columns added in
0085.

Split from the schema migration on purpose: this is the slow step on
production-sized M2M tables. ``atomic = False`` lets the UPDATE run outside a
single transaction, and the UPDATE writes absolute computed values (not
deltas) so it is idempotent — safe to re-run if interrupted. Collections with
no images keep the column ``default=0`` from 0085 (the GROUP BY only emits
rows for collections that have images).

``with_det`` checks for a valid (non-null / non-empty) detection bbox to match
the runtime ``NULL_DETECTIONS_FILTER`` semantics in ``ami/main/models.py``.
"""

from django.db import migrations


def backfill_counts(apps, schema_editor):
    """Populate the three cached count columns with a single UPDATE.

    A grouped subquery over the collection/image M2M table yields, per
    collection: the number of images (``total``), the number of images with
    at least one detection row (``processed``), and the number of images with
    at least one detection whose bbox is non-null and not ``[]``
    (``with_det``). Collections that have no images produce no subquery row
    and are therefore not updated.

    ``apps`` is unused; it is part of the ``RunPython`` callable signature.
    """
    backfill_sql = """
UPDATE main_sourceimagecollection sc
SET source_images_count = c.total,
source_images_processed_count = c.processed,
source_images_with_detections_count = c.with_det
FROM (
SELECT msci.sourceimagecollection_id AS coll_id,
COUNT(*) AS total,
COUNT(*) FILTER (
WHERE EXISTS (
SELECT 1 FROM main_detection d
WHERE d.source_image_id = si.id
)
) AS processed,
COUNT(*) FILTER (
WHERE EXISTS (
SELECT 1 FROM main_detection d
WHERE d.source_image_id = si.id
AND d.bbox IS NOT NULL
AND d.bbox::text <> '[]'
)
) AS with_det
FROM main_sourceimagecollection_images msci
INNER JOIN main_sourceimage si ON si.id = msci.sourceimage_id
GROUP BY msci.sourceimagecollection_id
) c
WHERE sc.id = c.coll_id;
"""
    schema_editor.execute(backfill_sql)


def reverse_noop(apps, schema_editor):
    """Reverse step for the backfill: intentionally does nothing."""
    return None


class Migration(migrations.Migration):
    """Data-only migration: backfill the ``SourceImageCollection`` count columns."""

    # Do not wrap the whole migration in one transaction; the backfill UPDATE
    # writes absolute values, so re-running it after an interruption is safe.
    atomic = False

    # The columns being backfilled are created by 0085.
    dependencies = [
        ("main", "0085_denormalize_sourceimagecollection_counts"),
    ]

    operations = [
        # Forward: run the backfill UPDATE. Reverse: no-op.
        migrations.RunPython(backfill_counts, reverse_noop),
    ]
73 changes: 73 additions & 0 deletions ami/main/migrations/0087_use_cached_count_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Generated by Django 4.2.10 on 2026-05-13 20:05

import ami.base.models
from django.db import migrations


class Migration(migrations.Migration):
    """Switch the cached count columns to ``ami.base.models.CachedCountField``.

    Every operation is an ``AlterField`` that changes only the field class.
    NOTE(review): ``CachedCountField`` documents itself as a marker subclass
    of ``IntegerField`` with an unchanged column type, so these operations
    are expected to emit no-op SQL — confirm with ``sqlmigrate``.
    """

    dependencies = [
        ("main", "0086_backfill_sourceimagecollection_counts"),
    ]

    operations = [
        # Deployment: nullable cached counts.
        migrations.AlterField(
            model_name="deployment",
            name="captures_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        migrations.AlterField(
            model_name="deployment",
            name="detections_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        migrations.AlterField(
            model_name="deployment",
            name="events_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        migrations.AlterField(
            model_name="deployment",
            name="occurrences_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        migrations.AlterField(
            model_name="deployment",
            name="taxa_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        # Event: nullable cached counts.
        migrations.AlterField(
            model_name="event",
            name="captures_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        migrations.AlterField(
            model_name="event",
            name="detections_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        migrations.AlterField(
            model_name="event",
            name="occurrences_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        # SourceImage: nullable cached count.
        migrations.AlterField(
            model_name="sourceimage",
            name="detections_count",
            field=ami.base.models.CachedCountField(blank=True, null=True),
        ),
        # SourceImageCollection: the non-null, default-0 columns added in 0085.
        migrations.AlterField(
            model_name="sourceimagecollection",
            name="source_images_count",
            field=ami.base.models.CachedCountField(default=0),
        ),
        migrations.AlterField(
            model_name="sourceimagecollection",
            name="source_images_processed_count",
            field=ami.base.models.CachedCountField(default=0),
        ),
        migrations.AlterField(
            model_name="sourceimagecollection",
            name="source_images_with_detections_count",
            field=ami.base.models.CachedCountField(default=0),
        ),
    ]
Loading
Loading