Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ami/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
from ami.users.models import User


class CachedCountField(models.IntegerField):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mohamedelabbas1996 I think we discussed a field type for cached counts last year. it finally happened!

"""Denormalized count of related rows.

Marker subclass so cached aggregate columns can be discovered via
``Model._meta.get_fields()`` (e.g. for refresh tasks, admin display, or
list-endpoint defer()). Values may be stale or null between
``update_calculated_fields`` calls — readers should not assume freshness.
"""

description = "Cached count of related rows"


def has_one_to_many_project_relation(model: type[models.Model]) -> bool:
"""
Returns True if the model has any ForeignKey or OneToOneField relationship to Project.
Expand Down
8 changes: 1 addition & 7 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,13 +712,7 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin):
Endpoint for viewing capture sets or samples of captures.
"""

queryset = (
SourceImageCollection.objects.all()
.with_source_images_count() # type: ignore
.with_source_images_with_detections_count()
.with_source_images_processed_count()
.prefetch_related("jobs")
)
queryset = SourceImageCollection.objects.all().prefetch_related("jobs")
serializer_class = SourceImageCollectionSerializer
permission_classes = [
ObjectPermission,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Denormalize three counts onto ``SourceImageCollection`` so the list endpoint
reads them in O(1) instead of running 3 correlated count subqueries per row.

Backfill uses a single GROUP BY over the M2M with FILTER clauses to compute
all three counts in one pass. ``with_det`` checks for a valid (non-null /
non-empty) detection bbox to match the runtime
``NULL_DETECTIONS_FILTER`` semantics in ``ami/main/models.py``.

``atomic = False`` so the long UPDATE can run outside a single transaction
on production-sized M2M tables.
"""

from django.db import migrations, models


def backfill_counts(apps, schema_editor):
schema_editor.execute(
"""
UPDATE main_sourceimagecollection sc
SET source_images_count = c.total,
source_images_processed_count = c.processed,
source_images_with_detections_count = c.with_det
FROM (
SELECT msci.sourceimagecollection_id AS coll_id,
COUNT(*) AS total,
COUNT(*) FILTER (
WHERE EXISTS (
SELECT 1 FROM main_detection d
WHERE d.source_image_id = si.id
)
) AS processed,
COUNT(*) FILTER (
WHERE EXISTS (
SELECT 1 FROM main_detection d
WHERE d.source_image_id = si.id
AND d.bbox IS NOT NULL
AND d.bbox::text <> '[]'
)
) AS with_det
FROM main_sourceimagecollection_images msci
INNER JOIN main_sourceimage si ON si.id = msci.sourceimage_id
GROUP BY msci.sourceimagecollection_id
) c
WHERE sc.id = c.coll_id;
"""
)
# Collections with no images: paginated SELECTs returned 0 via Coalesce; keep
# them populated rather than NULL so the column reads stay consistent.
schema_editor.execute(
"""
UPDATE main_sourceimagecollection
SET source_images_count = 0,
source_images_processed_count = 0,
source_images_with_detections_count = 0
WHERE source_images_count IS NULL;
"""
)


def reverse_noop(apps, schema_editor):
pass


class Migration(migrations.Migration):
atomic = False

dependencies = [
("main", "0084_revoke_delete_job_from_roles"),
]

operations = [
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_count",
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_with_detections_count",
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name="sourceimagecollection",
name="source_images_processed_count",
field=models.IntegerField(default=0),
),
Comment on lines +24 to +38
migrations.RunPython(backfill_counts, reverse_noop),
]
73 changes: 73 additions & 0 deletions ami/main/migrations/0086_use_cached_count_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Generated by Django 4.2.10 on 2026-05-13 20:05

import ami.base.models
from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("main", "0085_denormalize_sourceimagecollection_counts"),
]

operations = [
migrations.AlterField(
model_name="deployment",
name="captures_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="detections_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="events_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="occurrences_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="deployment",
name="taxa_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="event",
name="captures_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="event",
name="detections_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="event",
name="occurrences_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="sourceimage",
name="detections_count",
field=ami.base.models.CachedCountField(blank=True, null=True),
),
migrations.AlterField(
model_name="sourceimagecollection",
name="source_images_count",
field=ami.base.models.CachedCountField(default=0),
),
migrations.AlterField(
model_name="sourceimagecollection",
name="source_images_processed_count",
field=ami.base.models.CachedCountField(default=0),
),
migrations.AlterField(
model_name="sourceimagecollection",
name="source_images_with_detections_count",
field=ami.base.models.CachedCountField(default=0),
),
]
88 changes: 39 additions & 49 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import ami.tasks
import ami.utils
from ami.base.fields import DateStringField
from ami.base.models import BaseModel, BaseQuerySet
from ami.base.models import BaseModel, BaseQuerySet, CachedCountField
from ami.main import charts
from ami.main.models_future.filters import (
build_occurrence_default_filters_q,
Expand Down Expand Up @@ -754,11 +754,11 @@ class Deployment(BaseModel):
# data_source_last_check_notes = models.TextField(max_length=255, blank=True, null=True)

# Pre-calculated values
events_count = models.IntegerField(blank=True, null=True)
occurrences_count = models.IntegerField(blank=True, null=True)
captures_count = models.IntegerField(blank=True, null=True)
detections_count = models.IntegerField(blank=True, null=True)
taxa_count = models.IntegerField(blank=True, null=True)
events_count = CachedCountField(blank=True, null=True)
occurrences_count = CachedCountField(blank=True, null=True)
captures_count = CachedCountField(blank=True, null=True)
detections_count = CachedCountField(blank=True, null=True)
taxa_count = CachedCountField(blank=True, null=True)
first_capture_timestamp = models.DateTimeField(blank=True, null=True)
last_capture_timestamp = models.DateTimeField(blank=True, null=True)

Expand Down Expand Up @@ -1155,9 +1155,9 @@ class Event(BaseModel):
occurrences: models.QuerySet["Occurrence"]

# Pre-calculated values
captures_count = models.IntegerField(blank=True, null=True)
detections_count = models.IntegerField(blank=True, null=True)
occurrences_count = models.IntegerField(blank=True, null=True)
captures_count = CachedCountField(blank=True, null=True)
detections_count = CachedCountField(blank=True, null=True)
occurrences_count = CachedCountField(blank=True, null=True)
calculated_fields_updated_at = models.DateTimeField(blank=True, null=True)

class Meta:
Expand Down Expand Up @@ -1942,7 +1942,7 @@ class SourceImage(BaseModel):
test_image = models.BooleanField(default=False)

# Precaclulated values
detections_count = models.IntegerField(null=True, blank=True)
detections_count = CachedCountField(null=True, blank=True)

project = models.ForeignKey(Project, on_delete=models.SET_NULL, null=True, related_name="captures")
deployment = models.ForeignKey(Deployment, on_delete=models.SET_NULL, null=True, related_name="captures")
Expand Down Expand Up @@ -4093,32 +4093,6 @@ def html(self) -> str:


class SourceImageCollectionQuerySet(BaseQuerySet):
def with_source_images_count(self):
return self.annotate(
source_images_count=models.Count(
"images",
distinct=True,
)
)

def with_source_images_with_detections_count(self):
return self.annotate(
source_images_with_detections_count=models.Count(
"images",
filter=(~models.Q(images__detections__bbox__isnull=True) & ~models.Q(images__detections__bbox=[])),
distinct=True,
)
)

def with_source_images_processed_count(self):
return self.annotate(
source_images_processed_count=models.Count(
"images",
filter=models.Q(images__detections__isnull=False),
distinct=True,
)
)

def with_source_images_processed_by_algorithm_count(self, algorithm_id: int):
return self.annotate(
source_images_processed_by_algorithm_count=models.Count(
Expand Down Expand Up @@ -4205,6 +4179,12 @@ class SourceImageCollection(BaseModel):
default=dict,
)

# Denormalized counts. Kept in sync via m2m_changed and pipeline-completion
# hooks. Reads are O(1).
source_images_count = CachedCountField(default=0)
source_images_with_detections_count = CachedCountField(default=0)
source_images_processed_count = CachedCountField(default=0)
Comment on lines +4184 to +4186
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch


objects = SourceImageCollectionManager()

jobs: models.QuerySet["Job"]
Expand All @@ -4219,19 +4199,6 @@ def infer_dataset_type(self):
def dataset_type(self):
return self.infer_dataset_type()

def source_images_count(self) -> int | None:
# This should always be pre-populated using queryset annotations
# return self.images.count()
return None

def source_images_with_detections_count(self) -> int | None:
# This should always be pre-populated using queryset annotations
return None

def source_images_processed_count(self) -> int | None:
# This should always be pre-populated using queryset annotations
return None

def occurrences_count(self) -> int | None:
# This should always be pre-populated using queryset annotations
return None
Expand All @@ -4240,6 +4207,29 @@ def taxa_count(self) -> int | None:
# This should always be pre-populated using queryset annotations
return None

def get_source_image_counts(self) -> dict[str, int]:
"""Return the 3 source-image counts as a dict. Single aggregate query; does not write to the DB."""
valid_det = Detection.objects.filter(source_image=models.OuterRef("pk")).exclude(NULL_DETECTIONS_FILTER)
any_det = Detection.objects.filter(source_image=models.OuterRef("pk"))
counts = self.images.annotate(
_has_any_det=Exists(any_det),
_has_valid_det=Exists(valid_det),
).aggregate(
source_images_count=models.Count("id"),
source_images_processed_count=models.Count("id", filter=models.Q(_has_any_det=True)),
source_images_with_detections_count=models.Count("id", filter=models.Q(_has_valid_det=True)),
)
return counts

def update_calculated_fields(self, save: bool = False) -> None:
"""Recompute the 3 denormalized source-image count columns."""
counts = self.get_source_image_counts()
self.source_images_count = counts["source_images_count"]
self.source_images_processed_count = counts["source_images_processed_count"]
self.source_images_with_detections_count = counts["source_images_with_detections_count"]
if save:
SourceImageCollection.objects.filter(pk=self.pk).update(**counts)

def get_queryset(
self,
*args,
Expand Down
30 changes: 28 additions & 2 deletions ami/main/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from django.contrib.auth.models import Group
from django.db import transaction
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.db.models.signals import m2m_changed, post_delete, post_save, pre_delete, pre_save
from django.dispatch import receiver
from guardian.shortcuts import assign_perm

from ami.main.models import Project
from ami.main.models import Detection, Project, SourceImageCollection
from ami.main.tasks import refresh_project_cached_counts
from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project

Expand Down Expand Up @@ -197,3 +197,29 @@ def exclude_taxa_updated(sender, instance: Project, action, **kwargs):
if action in ["post_add", "post_remove", "post_clear"]:
logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})")
refresh_cached_counts_for_project(instance)


# ============================================================================
# SourceImageCollection Denormalized Counts
# ============================================================================


@receiver(m2m_changed, sender=SourceImageCollection.images.through)
def update_collection_counts_on_m2m(sender, instance, action, **kwargs):
"""Recompute denormalized counts when images are added to or removed from a collection."""
if action in ("post_add", "post_remove", "post_clear"):
instance.update_calculated_fields(save=True)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated


@receiver(post_save, sender=Detection)
@receiver(post_delete, sender=Detection)
def update_collection_counts_on_detection_change(sender, instance, **kwargs):
"""Keep processed / with-detections counts fresh on per-row Detection writes.

`bulk_create` skips signals, so ML pipelines must call `update_calculated_fields`
explicitly after their batch writes (see `ami.ml.models.pipeline.save_results`).
"""
if not instance.source_image_id:
return
for collection in SourceImageCollection.objects.filter(images__id=instance.source_image_id).distinct():
collection.update_calculated_fields(save=True)
Loading
Loading