Skip to content

Commit d839c3e

Browse files
committed
fix: Use source id weighted
1 parent 1764ecd commit d839c3e

2 files changed

Lines changed: 85 additions & 12 deletions

File tree

backend/src/ref_backend/core/outliers.py

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,57 @@ def flag_outliers_iqr(values: Sequence[float], factor: float = 5.0, min_n: int =
3030
return [v < lower_bound or v > upper_bound for v in values]
3131

3232

33+
def calculate_iqr_bounds_by_source_id(
34+
df: pd.DataFrame, factor: float = 3.0, min_n: int = 4
35+
) -> tuple[float, float] | None:
36+
"""
37+
Calculate IQR bounds using source_id means for equal model weighting.
38+
39+
This function calculates mean value for each source_id and then
40+
computes IQR bounds on these means, ensuring each model gets equal
41+
weight regardless of number of ensemble members.
42+
43+
Parameters
44+
----------
45+
df : pd.DataFrame
46+
DataFrame containing scalar values with dimensions including source_id
47+
factor : float
48+
The factor to multiply IQR by to determine outlier bounds
49+
min_n : int
50+
Minimum number of source_ids required to perform outlier detection
51+
52+
Returns
53+
-------
54+
tuple[float, float] | None
55+
Tuple of (lower_bound, upper_bound) or None if insufficient data
56+
"""
57+
# Check if source_id column exists
58+
if "source_id" not in df.columns:
59+
return None
60+
61+
# Separate Reference values (exclude from IQR calculation)
62+
reference_mask = df["source_id"] == "Reference"
63+
non_reference_df = df[~reference_mask]
64+
65+
# Group by source_id and calculate mean for each
66+
source_id_means = non_reference_df.groupby("source_id")["value"].mean()
67+
68+
# Check if we have enough source_ids for outlier detection
69+
if len(source_id_means) < min_n:
70+
return None
71+
72+
# Calculate IQR on source_id means
73+
means_list = source_id_means.tolist()
74+
quantiles = statistics.quantiles(means_list, n=4, method="inclusive")
75+
q1, q3 = quantiles[0], quantiles[2]
76+
iqr = q3 - q1
77+
78+
lower_bound = q1 - factor * iqr
79+
upper_bound = q3 + factor * iqr
80+
81+
return lower_bound, upper_bound
82+
83+
3384
def detect_outliers_in_scalar_values(
3485
scalar_values: Sequence[models.ScalarMetricValue],
3586
factor: float = 3.0,
@@ -38,6 +89,11 @@ def detect_outliers_in_scalar_values(
3889
) -> tuple[list[AnnotatedScalarValue], int]:
3990
"""Detect outliers in scalar metric values grouped by stable diagnostic facets.
4091
92+
This function uses source_id-aware outlier detection, where IQR bounds are calculated
93+
using the mean value of each source_id rather than on all individual ensemble members.
94+
This ensures each model gets equal weight regardless of number of ensemble members.
95+
The calculated bounds are then applied to individual values for outlier detection.
96+
4197
Parameters
4298
----------
4399
scalar_values
@@ -49,7 +105,7 @@ def detect_outliers_in_scalar_values(
49105
50106
Defaults to 3.0.
51107
min_n
52-
The minimum number of data points required in a group to perform
108+
The minimum number of source_ids required in a group to perform
53109
IQR outlier detection. Defaults to 4.
54110
group_by
55111
A sequence of dimension names to group the `scalar_values` by before
@@ -72,20 +128,37 @@ def detect_outliers_in_scalar_values(
72128
group_by = [g for g in group_by if g in df.columns]
73129

74130
for _, group_values in df.groupby(list(group_by)):
75-
print(group_values)
76131
# Identify non-finite values (NaN, inf) as outliers
77132
finite_flags = group_values.value.apply(
78133
lambda x: isinstance(x, int | float) and not math.isinf(x) and not math.isnan(x)
79134
)
80-
# Apply IQR only if group has enough values
81-
if len(group_values) >= min_n:
82-
iqr_flags = flag_outliers_iqr(group_values.value.to_list(), factor=factor)
135+
# Apply source_id-aware outlier detection if source_id exists
136+
if "source_id" in group_values.columns and len(group_values) >= min_n:
137+
iqr_bounds = calculate_iqr_bounds_by_source_id(group_values, factor=factor, min_n=min_n)
138+
139+
if iqr_bounds is not None:
140+
lower_bound, upper_bound = iqr_bounds
141+
# Apply bounds to individual values (Reference values always non-outlier)
142+
source_id_flags = group_values.apply(
143+
lambda row: (row["value"] < lower_bound or row["value"] > upper_bound)
144+
if row["source_id"] != "Reference"
145+
else False,
146+
axis=1,
147+
)
148+
else:
149+
# Fallback if insufficient source_ids
150+
source_id_flags = [False] * len(group_values) # type: ignore
83151
else:
84-
iqr_flags = [False] * len(group_values)
85-
86-
# Combine flags: item is outlier if iqr-flagged
87-
for sv, is_outside_iqr, is_finite in zip(group_values.scalar_value, iqr_flags, finite_flags):
88-
is_outlier = is_outside_iqr or not is_finite
152+
# Fallback to original IQR method if no source_id or insufficient data
153+
if len(group_values) >= min_n:
154+
iqr_flags = flag_outliers_iqr(group_values.value.to_list(), factor=factor)
155+
else:
156+
iqr_flags = [False] * len(group_values)
157+
source_id_flags = iqr_flags # type: ignore
158+
159+
# Combine flags: item is outlier if flagged by source_id method OR non-finite
160+
for sv, is_source_outlier, is_finite in zip(group_values.scalar_value, source_id_flags, finite_flags):
161+
is_outlier = is_source_outlier or not is_finite
89162
verification_status: Literal["verified", "unverified"] = (
90163
"unverified" if is_outlier else "verified"
91164
)

backend/tests/test_core/test_outliers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class TestFlagOutliersIQR:
99
def test_outlier_detection_with_factor_3(self):
1010
"""Test outlier detection with factor=3.0 on controlled data."""
1111
values = [1, 2, 2, 2, 100]
12-
result = flag_outliers_iqr(values, factor=3.0)
12+
result = flag_outliers_iqr(values, factor=3.0, min_n=4)
1313
expected = [True, False, False, False, True]
1414
assert result == expected
1515

@@ -23,7 +23,7 @@ def test_outlier_detection_with_factor_1_5(self):
2323
- Any value outside [2, 2] is an outlier, so 100 is flagged
2424
"""
2525
values = [1, 2, 2, 2, 100]
26-
result = flag_outliers_iqr(values, factor=1.5)
26+
result = flag_outliers_iqr(values, factor=1.5, min_n=4)
2727
expected = [True, False, False, False, True]
2828
assert result == expected
2929

0 commit comments

Comments
 (0)