Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,3 +1749,75 @@ class TopIdentifiersResponseSerializer(serializers.Serializer):

project_id = serializers.IntegerField()
top_identifiers = UserIdentificationCountSerializer(many=True)


class ModelAgreementSerializer(serializers.Serializer):
"""Verified / agreement rates over the filtered Occurrence set.

`agreed_exact_count` is a subset of `agreed_any_rank_count` by
construction — an exact match implies the LCA is the taxon itself.
`*_pct` percentages are 0.0..1.0 (not 0..100).

Denominator note: `agreed_*_pct` divide by `verified_with_prediction_count`
(verified occurrences that *also* have a machine prediction), NOT by
`verified_count`. A verified occurrence with no machine prediction can't
agree or disagree — including it in the denominator would drag the rate
down without representing actual model disagreement. `no_prediction_count`
is surfaced so the consumer can see how many such occurrences exist.

Optional rank threshold: when the caller passes
`?agreement_coarsest_rank=FAMILY`, the response also includes
`agreed_coarser_rank_*` counting only LCAs at that rank or deeper. The
threshold rank is echoed in `agreement_coarsest_rank`. When the param is
absent, the coarser-rank fields are null and `agreement_coarsest_rank`
is null.
"""

project_id = serializers.IntegerField()
total_occurrences = serializers.IntegerField()
verified_count = serializers.IntegerField(help_text="Occurrences with at least one non-withdrawn identification.")
verified_pct = serializers.FloatField(
min_value=0.0,
max_value=1.0,
help_text="verified_count / total_occurrences",
)
verified_with_prediction_count = serializers.IntegerField(
help_text="Verified occurrences that also have a machine prediction (denominator for agreed_*_pct)."
)
no_prediction_count = serializers.IntegerField(
help_text="Verified occurrences with no machine prediction (excluded from agreement denominator)."
)
agreed_exact_count = serializers.IntegerField()
agreed_exact_pct = serializers.FloatField(
min_value=0.0,
max_value=1.0,
help_text="agreed_exact_count / verified_with_prediction_count",
)
agreed_any_rank_count = serializers.IntegerField(
help_text="Exact matches plus disagreements whose LCA is at any real rank (UNKNOWN excluded)."
)
agreed_any_rank_pct = serializers.FloatField(
min_value=0.0,
max_value=1.0,
help_text="agreed_any_rank_count / verified_with_prediction_count",
)
agreement_coarsest_rank = serializers.CharField(
allow_null=True,
required=False,
help_text="Threshold rank from ?agreement_coarsest_rank query param. Null when the param is absent.",
)
agreed_coarser_rank_count = serializers.IntegerField(
allow_null=True,
required=False,
help_text=(
"Exact matches plus disagreements whose LCA is at `agreement_coarsest_rank` or deeper. "
"Null when no threshold was supplied."
),
)
agreed_coarser_rank_pct = serializers.FloatField(
min_value=0.0,
max_value=1.0,
allow_null=True,
required=False,
help_text="agreed_coarser_rank_count / verified_with_prediction_count. Null when no threshold supplied.",
)
87 changes: 70 additions & 17 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ami.base.views import ProjectMixin
from ami.main.api.schemas import limit_doc_param, project_id_doc_param
from ami.main.api.serializers import TagSerializer
from ami.main.models_future.occurrence import top_identifiers_for_project
from ami.main.models_future.occurrence import model_agreement_for_project, top_identifiers_for_project
from ami.utils.requests import get_default_classification_threshold
from ami.utils.storages import ConnectionTestResult

Expand All @@ -56,6 +56,7 @@
Tag,
TaxaList,
Taxon,
TaxonRank,
User,
update_detection_counts,
)
Expand All @@ -72,6 +73,7 @@
EventSerializer,
EventTimelineSerializer,
IdentificationSerializer,
ModelAgreementSerializer,
OccurrenceListSerializer,
OccurrenceSerializer,
PageListSerializer,
Expand Down Expand Up @@ -1168,6 +1170,24 @@ def filter_queryset(self, request, queryset, view):
return queryset


OCCURRENCE_FILTER_BACKENDS = (
CustomOccurrenceDeterminationFilter,
OccurrenceCollectionFilter,
OccurrenceAlgorithmFilter,
OccurrenceDateFilter,
OccurrenceVerified,
OccurrenceVerifiedByMeFilter,
OccurrenceTaxaListFilter,
)

OCCURRENCE_FILTERSET_FIELDS = (
"event",
"deployment",
"determination__rank",
"detections__source_image",
)


class OccurrenceViewSet(DefaultViewSet, ProjectMixin):
"""
API endpoint that allows occurrences to be viewed or edited.
Expand All @@ -1177,22 +1197,8 @@ class OccurrenceViewSet(DefaultViewSet, ProjectMixin):
queryset = Occurrence.objects.all()

serializer_class = OccurrenceSerializer
# filter_backends = [CustomDeterminationFilter, DjangoFilterBackend, NullsLastOrderingFilter, SearchFilter]
filter_backends = DefaultViewSetMixin.filter_backends + [
CustomOccurrenceDeterminationFilter,
OccurrenceCollectionFilter,
OccurrenceAlgorithmFilter,
OccurrenceDateFilter,
OccurrenceVerified,
OccurrenceVerifiedByMeFilter,
OccurrenceTaxaListFilter,
]
filterset_fields = [
"event",
"deployment",
"determination__rank",
"detections__source_image",
]
filter_backends = DefaultViewSetMixin.filter_backends + list(OCCURRENCE_FILTER_BACKENDS)
filterset_fields = list(OCCURRENCE_FILTERSET_FIELDS)
ordering_fields = [
"created_at",
"updated_at",
Expand Down Expand Up @@ -1290,6 +1296,11 @@ class OccurrenceStatsViewSet(viewsets.GenericViewSet, ProjectMixin):

permission_classes = [IsActiveStaffOrReadOnly]
require_project = True
# Filter machinery for actions that opt into `self.filter_queryset(...)`.
# `top_identifiers` doesn't call it, so its behavior is unchanged.
queryset = Occurrence.objects.none()
filter_backends = [DjangoFilterBackend, *OCCURRENCE_FILTER_BACKENDS]
filterset_fields = list(OCCURRENCE_FILTERSET_FIELDS)

Comment thread
mihow marked this conversation as resolved.
@extend_schema(
parameters=[project_id_doc_param, limit_doc_param],
Expand Down Expand Up @@ -1320,6 +1331,48 @@ def top_identifiers(self, request):
)
return Response(serializer.data)

@extend_schema(
parameters=[project_id_doc_param],
responses=ModelAgreementSerializer,
)
@action(detail=False, methods=["get"], url_path="model-agreement")
def model_agreement(self, request):
"""Verified / human↔model agreement rates over the filtered occurrence set.

Accepts every query param the `/occurrences/` list endpoint accepts.
Reuses `apply_default_filters` so `apply_defaults=false` bypasses
project default taxa lists + score thresholds.

Optional ?agreement_coarsest_rank=<RANK> adds `agreed_coarser_rank_*`
counts — LCAs at the given rank or deeper. Valid values: any
TaxonRank name (FAMILY, GENUS, etc.); invalid → 400.
"""
project = self.get_active_project()
assert project is not None # require_project=True guarantees this
if not Project.objects.visible_for_user(request.user).filter(pk=project.pk).exists():
raise NotFound("Project not found.")

coarsest_rank_param = request.query_params.get("agreement_coarsest_rank")
coarsest_rank = None
if coarsest_rank_param:
try:
coarsest_rank = TaxonRank[coarsest_rank_param.upper()]
except KeyError:
valid = ", ".join(r.name for r in TaxonRank if r.name != "UNKNOWN")
raise api_exceptions.ValidationError(
{"agreement_coarsest_rank": f"Invalid rank '{coarsest_rank_param}'. Must be one of: {valid}."}
)
if coarsest_rank == TaxonRank.UNKNOWN:
raise api_exceptions.ValidationError(
{"agreement_coarsest_rank": "UNKNOWN is not a valid threshold rank."}
)

base_qs = Occurrence.objects.filter(project=project).valid().apply_default_filters(project, request)
filtered_qs = self.filter_queryset(base_qs)
payload = model_agreement_for_project(filtered_qs, coarsest_rank=coarsest_rank)
payload["project_id"] = project.pk
return Response(ModelAgreementSerializer(payload, context={"request": request}).data)


class TaxonTaxaListFilter(filters.BaseFilterBackend):
"""
Expand Down
168 changes: 166 additions & 2 deletions ami/main/models_future/occurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,43 @@

from typing import TYPE_CHECKING

from django.db.models import Count, Prefetch, Q, QuerySet
from django.db.models import Count, OuterRef, Prefetch, Q, QuerySet, Subquery

from ami.main.models import Project, User
from ami.main.models import Project, TaxonRank, User

if TYPE_CHECKING:
from ami.main.models import Classification, Identification, Occurrence

TaxonTuple = tuple[int, str, list[dict]]


def lca_rank_between(a: TaxonTuple, b: TaxonTuple) -> TaxonRank | None:
"""Most-specific shared ancestor rank between two taxa.

Inputs are ``(taxon_id, rank_str, parents_json)`` triples where
``parents_json`` is ordered root → immediate parent (Taxon.parents_json layout).

The taxon itself counts as part of its own ancestor chain — passing the
same taxon twice returns that taxon's rank. Returns ``None`` when the two
chains share no ancestor at a real taxonomic rank.

``TaxonRank.UNKNOWN`` is excluded from the candidate set even though it
sorts after SPECIES in OrderedEnum definition order — it isn't a real
taxonomic rank and treating it as deeper-than-ORDER produces false
under-order agreements when an UNKNOWN ancestor happens to be shared.
"""
chain_a = [(p["id"], TaxonRank(p["rank"])) for p in a[2]] + [(a[0], TaxonRank(a[1]))]
chain_b_ids = {p["id"] for p in b[2]} | {b[0]}

deepest: TaxonRank | None = None
for tid, rank in chain_a:
if rank == TaxonRank.UNKNOWN:
continue
if tid in chain_b_ids:
if deepest is None or rank > deepest:
deepest = rank
return deepest


def _detections_prefetch(*, ordering: tuple[str, ...], with_source_image: bool) -> Prefetch:
from ami.main.models import Classification, Detection
Expand Down Expand Up @@ -133,6 +163,140 @@ def detection_image_urls_from_prefetch(occurrence: Occurrence, limit: int | None
return [get_media_url(det.path) for det in detections]


def model_agreement_for_project(
queryset: QuerySet[Occurrence],
coarsest_rank: TaxonRank | None = None,
) -> dict:
"""Verified / agreement stats over a pre-filtered Occurrence queryset.

The queryset MUST already be filtered to the project + user-supplied
filters (caller wires apply_default_filters + OccurrenceFilter). This
function adds the annotations it needs and returns a dict matching
ModelAgreementSerializer's field set (without project_id — the view
layer adds that).

"Verified" means the occurrence has at least one non-withdrawn
Identification. "Model prediction" means the Classification chosen by
BEST_MACHINE_PREDICTION_ORDER. "Any-rank" agreement means the user's
taxon and the model's prediction share an ancestor at any real rank
(UNKNOWN excluded) — exact matches included. The upstream filter (e.g.
a Lepidoptera include list) is what bounds the meaningful scope, not
a hardcoded rank threshold in this function.

When ``coarsest_rank`` is supplied, additionally compute "coarser-rank"
agreement: the LCA must be at ``coarsest_rank`` or deeper (e.g. passing
FAMILY only counts LCAs at FAMILY, GENUS, or SPECIES). Exact matches
always count regardless of rank.

Performance: the heavy work — correlated subqueries over Identification
and Classification — is scoped to the verified set, which is typically
a tiny fraction of total occurrences. Computing those subqueries over
the full filtered queryset would do 99% wasted work picking the "best
user identification" for occurrences that have none.

Step 1: total_occurrences = SQL Count(*).
Step 2: Fetch the verified set with (pk, best_user_taxon_id,
best_machine_prediction_taxon_id). Both correlated subqueries
evaluate only on verified rows.
Step 3: Bucket counts in Python (set is small).
Step 4: Dedupe disagreement to distinct (user, machine) pairs and run
one LCA per pair.

Bench against project 18 (43,149 occurrences, 45 verified): ~80ms cold.
"""
import collections

from ami.main.models import BEST_IDENTIFICATION_ORDER, Identification, Taxon

# Default filters can join Identification (verified_by_me) and Taxon
# parents_json (taxa_list_id) which inflates row count if not deduped.
# Dedupe up front so total + verified counts share one canonical set.
queryset = queryset.distinct()
total = queryset.count()

best_user_ident = Identification.objects.filter(occurrence=OuterRef("pk"), withdrawn=False).order_by(
*BEST_IDENTIFICATION_ORDER
)

verified_rows = list(
queryset.filter(identifications__withdrawn=False)
.distinct()
.with_best_machine_prediction() # type: ignore[attr-defined]
.annotate(best_user_taxon_id=Subquery(best_user_ident.values("taxon_id")[:1]))
.values("pk", "best_machine_prediction_taxon_id", "best_user_taxon_id")
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

verified = len(verified_rows)
no_prediction = sum(1 for r in verified_rows if r["best_machine_prediction_taxon_id"] is None)
verified_with_pred = verified - no_prediction
agreed_exact = sum(
1
for r in verified_rows
if r["best_machine_prediction_taxon_id"] is not None
and r["best_user_taxon_id"] == r["best_machine_prediction_taxon_id"]
)

# Dedupe disagreement pairs so each (user_taxon, machine_taxon) LCA runs once.
pair_counts: collections.Counter = collections.Counter()
for r in verified_rows:
m_id = r["best_machine_prediction_taxon_id"]
u_id = r["best_user_taxon_id"]
if m_id is None or u_id is None or u_id == m_id:
continue
pair_counts[(u_id, m_id)] += 1

needed_taxa_ids: set[int] = set()
for u_id, m_id in pair_counts:
needed_taxa_ids.add(u_id)
needed_taxa_ids.add(m_id)

taxa_by_id: dict[int, TaxonTuple] = {}
if needed_taxa_ids:
for t in Taxon.objects.filter(pk__in=needed_taxa_ids):
parents = [
{"id": p.id, "rank": p.rank.name if hasattr(p.rank, "name") else p.rank} for p in t.parents_json
]
taxa_by_id[t.pk] = (t.pk, t.rank, parents)

any_rank_disagreement_count = 0
coarser_rank_disagreement_count = 0
for (u_id, m_id), count in pair_counts.items():
u = taxa_by_id.get(u_id)
m = taxa_by_id.get(m_id)
if not u or not m:
continue
lca = lca_rank_between(u, m)
if lca is None:
continue
any_rank_disagreement_count += count
if coarsest_rank is not None and lca >= coarsest_rank:
coarser_rank_disagreement_count += count

agreed_any_rank = agreed_exact + any_rank_disagreement_count
agreed_coarser_rank = agreed_exact + coarser_rank_disagreement_count

def _pct(num: int, denom: int) -> float:
return round(num / denom, 4) if denom else 0.0

payload: dict = {
"total_occurrences": total,
"verified_count": verified,
"verified_pct": _pct(verified, total),
"verified_with_prediction_count": verified_with_pred,
"no_prediction_count": no_prediction,
"agreed_exact_count": agreed_exact,
"agreed_exact_pct": _pct(agreed_exact, verified_with_pred),
"agreed_any_rank_count": agreed_any_rank,
"agreed_any_rank_pct": _pct(agreed_any_rank, verified_with_pred),
"agreement_coarsest_rank": coarsest_rank.name if coarsest_rank is not None else None,
"agreed_coarser_rank_count": agreed_coarser_rank if coarsest_rank is not None else None,
"agreed_coarser_rank_pct": (
_pct(agreed_coarser_rank, verified_with_pred) if coarsest_rank is not None else None
),
}
return payload


def top_identifiers_for_project(project: Project) -> QuerySet[User]:
"""Project users ranked by distinct occurrences they identified.

Expand Down
Loading
Loading