Skip to content
Draft
33 changes: 33 additions & 0 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,3 +1749,36 @@ class TopIdentifiersResponseSerializer(serializers.Serializer):

project_id = serializers.IntegerField()
top_identifiers = UserIdentificationCountSerializer(many=True)


class ModelAgreementSerializer(serializers.Serializer):
    """Response shape for the human-vs-model agreement statistics endpoint.

    All counts refer to the already-filtered Occurrence set of one project.

    Conventions:

    * Every ``*_pct`` field is a fraction in the range 0.0..1.0, NOT a
      0..100 percentage.
    * ``agreed_exact_count`` is always contained in
      ``agreed_under_order_count``: the under-order figure is produced by
      adding coarser (shared-ancestor) agreements on top of the exact
      matches.
    * Both ``agreed_*_pct`` fields use ``verified_with_prediction_count`` as
      the denominator rather than ``verified_count``. A verified occurrence
      that has no machine prediction can neither agree nor disagree with the
      model, so counting it would lower the rate without reflecting any real
      disagreement. Those occurrences are reported separately through
      ``no_prediction_count``.
    """

    # --- scope -----------------------------------------------------------
    project_id = serializers.IntegerField()
    total_occurrences = serializers.IntegerField()

    # --- human verification ------------------------------------------------
    verified_count = serializers.IntegerField(
        help_text="Occurrences with at least one non-withdrawn identification.",
    )
    verified_pct = serializers.FloatField(
        help_text="verified_count / total_occurrences",
    )
    verified_with_prediction_count = serializers.IntegerField(
        help_text="Verified occurrences that also have a machine prediction (denominator for agreed_*_pct).",
    )
    no_prediction_count = serializers.IntegerField(
        help_text="Verified occurrences with no machine prediction (excluded from agreement denominator).",
    )

    # --- agreement between the human and the model -------------------------
    agreed_exact_count = serializers.IntegerField()
    agreed_exact_pct = serializers.FloatField(
        help_text="agreed_exact_count / verified_with_prediction_count",
    )
    agreed_under_order_count = serializers.IntegerField()
    agreed_under_order_pct = serializers.FloatField(
        help_text="agreed_under_order_count / verified_with_prediction_count",
    )
67 changes: 50 additions & 17 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ami.base.views import ProjectMixin
from ami.main.api.schemas import limit_doc_param, project_id_doc_param
from ami.main.api.serializers import TagSerializer
from ami.main.models_future.occurrence import top_identifiers_for_project
from ami.main.models_future.occurrence import model_agreement_for_project, top_identifiers_for_project
from ami.utils.requests import get_default_classification_threshold
from ami.utils.storages import ConnectionTestResult

Expand Down Expand Up @@ -72,6 +72,7 @@
EventSerializer,
EventTimelineSerializer,
IdentificationSerializer,
ModelAgreementSerializer,
OccurrenceListSerializer,
OccurrenceSerializer,
PageListSerializer,
Expand Down Expand Up @@ -1168,6 +1169,24 @@ def filter_queryset(self, request, queryset, view):
return queryset


# Occurrence-specific filter backends, declared once at module level so that
# OccurrenceViewSet and OccurrenceStatsViewSet build their `filter_backends`
# from the same definition. Kept as a tuple; each viewset expands it into its
# own list.
OCCURRENCE_FILTER_BACKENDS = (
    CustomOccurrenceDeterminationFilter,
    OccurrenceCollectionFilter,
    OccurrenceAlgorithmFilter,
    OccurrenceDateFilter,
    OccurrenceVerified,
    OccurrenceVerifiedByMeFilter,
    OccurrenceTaxaListFilter,
)

# Field lookups used as `filterset_fields` by the same two viewsets.
# NOTE(review): `detections__source_image` traverses a to-many relation, so
# filtering on it can presumably yield duplicate Occurrence rows — confirm
# that consumers de-duplicate where it matters.
OCCURRENCE_FILTERSET_FIELDS = (
    "event",
    "deployment",
    "determination__rank",
    "detections__source_image",
)


class OccurrenceViewSet(DefaultViewSet, ProjectMixin):
"""
API endpoint that allows occurrences to be viewed or edited.
Expand All @@ -1177,22 +1196,8 @@ class OccurrenceViewSet(DefaultViewSet, ProjectMixin):
queryset = Occurrence.objects.all()

serializer_class = OccurrenceSerializer
# filter_backends = [CustomDeterminationFilter, DjangoFilterBackend, NullsLastOrderingFilter, SearchFilter]
filter_backends = DefaultViewSetMixin.filter_backends + [
CustomOccurrenceDeterminationFilter,
OccurrenceCollectionFilter,
OccurrenceAlgorithmFilter,
OccurrenceDateFilter,
OccurrenceVerified,
OccurrenceVerifiedByMeFilter,
OccurrenceTaxaListFilter,
]
filterset_fields = [
"event",
"deployment",
"determination__rank",
"detections__source_image",
]
filter_backends = DefaultViewSetMixin.filter_backends + list(OCCURRENCE_FILTER_BACKENDS)
filterset_fields = list(OCCURRENCE_FILTERSET_FIELDS)
ordering_fields = [
"created_at",
"updated_at",
Expand Down Expand Up @@ -1290,6 +1295,11 @@ class OccurrenceStatsViewSet(viewsets.GenericViewSet, ProjectMixin):

permission_classes = [IsActiveStaffOrReadOnly]
require_project = True
# Filter machinery for actions that opt into `self.filter_queryset(...)`.
# `top_identifiers` doesn't call it, so its behavior is unchanged.
queryset = Occurrence.objects.none()
filter_backends = [DjangoFilterBackend, *OCCURRENCE_FILTER_BACKENDS]
filterset_fields = list(OCCURRENCE_FILTERSET_FIELDS)

Comment thread
mihow marked this conversation as resolved.
@extend_schema(
parameters=[project_id_doc_param, limit_doc_param],
Expand Down Expand Up @@ -1320,6 +1330,29 @@ def top_identifiers(self, request):
)
return Response(serializer.data)

@extend_schema(
    parameters=[project_id_doc_param],
    responses=ModelAgreementSerializer,
)
@action(detail=False, methods=["get"], url_path="model-agreement")
def model_agreement(self, request):
    """Verified / human↔model agreement rates over the filtered occurrence set.

    Accepts the occurrence filter query params that this viewset shares with
    the `/occurrences/` list endpoint (`OCCURRENCE_FILTER_BACKENDS` and
    `OCCURRENCE_FILTERSET_FIELDS`). Reuses `apply_default_filters` so
    `apply_defaults=false` bypasses project default taxa lists + score
    thresholds.

    Responds with the `ModelAgreementSerializer` payload for the active
    project. Raises `NotFound` when no active project can be resolved or
    when the project is not visible to the requesting user.
    """
    project = self.get_active_project()
    # `require_project = True` is expected to guarantee a project here, but the
    # guard is an explicit raise rather than an `assert`: asserts are stripped
    # under `python -O`, which would turn a missing project into an
    # AttributeError (HTTP 500) further down.
    project_is_visible = (
        project is not None and Project.objects.visible_for_user(request.user).filter(pk=project.pk).exists()
    )
    if not project_is_visible:
        # Same message for "missing" and "not visible" so the response does
        # not reveal which of the two applies.
        raise NotFound("Project not found.")

    base_qs = Occurrence.objects.filter(project=project).valid().apply_default_filters(project, request)
    filtered_qs = self.filter_queryset(base_qs)

    payload = model_agreement_for_project(filtered_qs)
    # The stats helper is project-agnostic; the view owns the project id.
    payload["project_id"] = project.pk
    return Response(ModelAgreementSerializer(payload, context={"request": request}).data)


class TaxonTaxaListFilter(filters.BaseFilterBackend):
"""
Expand Down
143 changes: 141 additions & 2 deletions ami/main/models_future/occurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,43 @@

from typing import TYPE_CHECKING

from django.db.models import Count, Prefetch, Q, QuerySet
from django.db.models import Count, OuterRef, Prefetch, Q, QuerySet, Subquery

from ami.main.models import Project, User
from ami.main.models import Project, TaxonRank, User

if TYPE_CHECKING:
from ami.main.models import Classification, Identification, Occurrence

# Lightweight stand-in for a taxon: (taxon_id, rank, parents_json), where
# parents_json is a list of {"id": ..., "rank": ...} dicts ordered
# root -> immediate parent (as consumed by lca_rank_between).
TaxonTuple = tuple[int, str, list[dict]]


def lca_rank_between(a: TaxonTuple, b: TaxonTuple) -> TaxonRank | None:
    """Return the deepest rank at which the lineages of ``a`` and ``b`` meet.

    Each argument is a ``(taxon_id, rank_str, parents_json)`` triple, with
    ``parents_json`` listing ancestors from the root down to the immediate
    parent (the Taxon.parents_json layout).

    A taxon belongs to its own lineage, so calling this with the same taxon
    twice yields that taxon's rank. ``None`` is returned when the two
    lineages have no ancestor in common at a real taxonomic rank.

    Shared ancestors ranked ``TaxonRank.UNKNOWN`` are ignored. UNKNOWN sorts
    after SPECIES in the enum's definition order, so letting it take part in
    the comparison would make it look deeper than ORDER and report false
    under-order agreements.
    """
    # Lineage of `a`, root first, ending with `a` itself.
    lineage_a = [(parent["id"], TaxonRank(parent["rank"])) for parent in a[2]]
    lineage_a.append((a[0], TaxonRank(a[1])))

    # Only membership matters for `b`, so its lineage is reduced to an id set.
    lineage_b_ids = {parent["id"] for parent in b[2]}
    lineage_b_ids.add(b[0])

    shared_ranks = [
        rank for taxon_id, rank in lineage_a if rank != TaxonRank.UNKNOWN and taxon_id in lineage_b_ids
    ]
    # The greatest rank is the most specific one; no shared rank -> None.
    return max(shared_ranks, default=None)


def _detections_prefetch(*, ordering: tuple[str, ...], with_source_image: bool) -> Prefetch:
from ami.main.models import Classification, Detection
Expand Down Expand Up @@ -133,6 +163,115 @@ def detection_image_urls_from_prefetch(occurrence: Occurrence, limit: int | None
return [get_media_url(det.path) for det in detections]


def model_agreement_for_project(queryset: QuerySet[Occurrence]) -> dict:
    """Verified / agreement stats over a pre-filtered Occurrence queryset.

    The queryset MUST already be filtered to the project + user-supplied
    filters (caller wires apply_default_filters + the occurrence filter
    backends). This function adds the annotations it needs and returns a
    dict matching ModelAgreementSerializer's field set (without project_id —
    the view layer adds that).

    Definitions:

    * "Verified": the occurrence has at least one non-withdrawn
      Identification. The user's taxon is the first such identification by
      BEST_IDENTIFICATION_ORDER.
    * "Model prediction": the taxon annotated by
      ``with_best_machine_prediction()`` (described by the original author
      as the Classification chosen by BEST_MACHINE_PREDICTION_ORDER).
    * "Exact" agreement: user taxon id == model taxon id.
    * "Under-order" agreement: exact agreement, OR the two taxa share an
      ancestor at rank >= ORDER (inclusive of ORDER itself). Exact matches
      are always included, whatever the rank of the matched taxon.

    Verified rows whose best identification has no taxon but that do have a
    model prediction stay in the agreement denominator and count as neither
    kind of agreement.

    Performance: the heavy work — correlated subqueries over Identification
    and Classification — is scoped to the verified set, which is typically
    a tiny fraction of total occurrences.

    Step 1: total_occurrences = COUNT over DISTINCT occurrence pks, so that
            upstream filters joining to-many relations (e.g.
            ``detections__source_image``) cannot inflate the denominator
            relative to the verified set, which is also de-duplicated.
    Step 2: Fetch the verified set with (pk, best_user_taxon_id,
            best_machine_prediction_taxon_id).
    Step 3: Bucket counts in Python in a single pass (set is small).
    Step 4: Dedupe disagreement to distinct (user, machine) pairs and run
            one LCA per pair.

    Bench against project 18 (43,149 occurrences, 45 verified): ~80ms cold
    (measured with the earlier plain COUNT(*) in step 1).

    Returns:
        dict with keys ``total_occurrences``, ``verified_count``,
        ``verified_pct``, ``verified_with_prediction_count``,
        ``no_prediction_count``, ``agreed_exact_count``, ``agreed_exact_pct``,
        ``agreed_under_order_count`` and ``agreed_under_order_pct``. The
        ``*_pct`` values are fractions rounded to 4 decimals, and 0.0 when the
        denominator is 0.
    """
    from collections import Counter

    from ami.main.models import BEST_IDENTIFICATION_ORDER, Identification, Taxon

    # Step 1 — count each occurrence once, even if a join duplicated its row.
    total = queryset.order_by().values("pk").distinct().count()

    # Step 2 — verified rows with both "best" taxon ids attached.
    best_user_ident = Identification.objects.filter(occurrence=OuterRef("pk"), withdrawn=False).order_by(
        *BEST_IDENTIFICATION_ORDER
    )
    verified_rows = list(
        queryset.filter(identifications__withdrawn=False)
        .distinct()
        .with_best_machine_prediction()  # type: ignore[attr-defined]
        .annotate(best_user_taxon_id=Subquery(best_user_ident.values("taxon_id")[:1]))
        .values("pk", "best_machine_prediction_taxon_id", "best_user_taxon_id")
    )

    # Step 3 — one pass: no prediction / exact match / disagreement pair.
    verified = len(verified_rows)
    no_prediction = 0
    agreed_exact = 0
    # (user_taxon_id, machine_taxon_id) -> number of occurrences with that pair.
    pair_counts: Counter = Counter()
    for row in verified_rows:
        machine_id = row["best_machine_prediction_taxon_id"]
        user_id = row["best_user_taxon_id"]
        if machine_id is None:
            no_prediction += 1
        elif user_id == machine_id:
            agreed_exact += 1
        elif user_id is not None:
            pair_counts[(user_id, machine_id)] += 1
    verified_with_pred = verified - no_prediction

    # Step 4 — load each taxon involved in a disagreement once.
    needed_taxa_ids: set[int] = {taxon_id for pair in pair_counts for taxon_id in pair}

    taxa_by_id: dict[int, TaxonTuple] = {}
    if needed_taxa_ids:  # skip the query entirely when nothing disagrees
        for taxon in Taxon.objects.filter(pk__in=needed_taxa_ids):
            parents = [
                # Parent ranks may be enum members or plain strings.
                {"id": p.id, "rank": p.rank.name if hasattr(p.rank, "name") else p.rank}
                for p in taxon.parents_json
            ]
            taxa_by_id[taxon.pk] = (taxon.pk, taxon.rank, parents)

    under_order_disagreement_count = 0
    for (user_id, machine_id), count in pair_counts.items():
        user_taxon = taxa_by_id.get(user_id)
        machine_taxon = taxa_by_id.get(machine_id)
        if user_taxon is None or machine_taxon is None:
            # Taxon row no longer resolvable: cannot establish agreement.
            continue
        lca = lca_rank_between(user_taxon, machine_taxon)
        if lca is not None and lca >= TaxonRank.ORDER:
            under_order_disagreement_count += count

    agreed_under_order = agreed_exact + under_order_disagreement_count

    def _pct(num: int, denom: int) -> float:
        """Fraction num/denom rounded to 4 decimals; 0.0 for an empty denominator."""
        return round(num / denom, 4) if denom else 0.0

    return {
        "total_occurrences": total,
        "verified_count": verified,
        "verified_pct": _pct(verified, total),
        "verified_with_prediction_count": verified_with_pred,
        "no_prediction_count": no_prediction,
        "agreed_exact_count": agreed_exact,
        "agreed_exact_pct": _pct(agreed_exact, verified_with_pred),
        "agreed_under_order_count": agreed_under_order,
        "agreed_under_order_pct": _pct(agreed_under_order, verified_with_pred),
    }


def top_identifiers_for_project(project: Project) -> QuerySet[User]:
"""Project users ranked by distinct occurrences they identified.

Expand Down
Loading
Loading