@@ -30,6 +30,57 @@ def flag_outliers_iqr(values: Sequence[float], factor: float = 5.0, min_n: int =
3030 return [v < lower_bound or v > upper_bound for v in values ]
3131
3232
33+ def calculate_iqr_bounds_by_source_id (
34+ df : pd .DataFrame , factor : float = 3.0 , min_n : int = 4
35+ ) -> tuple [float , float ] | None :
36+ """
37+ Calculate IQR bounds using source_id means for equal model weighting.
38+
39+ This function calculates mean value for each source_id and then
40+ computes IQR bounds on these means, ensuring each model gets equal
41+ weight regardless of number of ensemble members.
42+
43+ Parameters
44+ ----------
45+ df : pd.DataFrame
46+ DataFrame containing scalar values with dimensions including source_id
47+ factor : float
48+ The factor to multiply IQR by to determine outlier bounds
49+ min_n : int
50+ Minimum number of source_ids required to perform outlier detection
51+
52+ Returns
53+ -------
54+ tuple[float, float] | None
55+ Tuple of (lower_bound, upper_bound) or None if insufficient data
56+ """
57+ # Check if source_id column exists
58+ if "source_id" not in df .columns :
59+ return None
60+
61+ # Separate Reference values (exclude from IQR calculation)
62+ reference_mask = df ["source_id" ] == "Reference"
63+ non_reference_df = df [~ reference_mask ]
64+
65+ # Group by source_id and calculate mean for each
66+ source_id_means = non_reference_df .groupby ("source_id" )["value" ].mean ()
67+
68+ # Check if we have enough source_ids for outlier detection
69+ if len (source_id_means ) < min_n :
70+ return None
71+
72+ # Calculate IQR on source_id means
73+ means_list = source_id_means .tolist ()
74+ quantiles = statistics .quantiles (means_list , n = 4 , method = "inclusive" )
75+ q1 , q3 = quantiles [0 ], quantiles [2 ]
76+ iqr = q3 - q1
77+
78+ lower_bound = q1 - factor * iqr
79+ upper_bound = q3 + factor * iqr
80+
81+ return lower_bound , upper_bound
82+
83+
3384def detect_outliers_in_scalar_values (
3485 scalar_values : Sequence [models .ScalarMetricValue ],
3586 factor : float = 3.0 ,
@@ -38,6 +89,11 @@ def detect_outliers_in_scalar_values(
3889) -> tuple [list [AnnotatedScalarValue ], int ]:
3990 """Detect outliers in scalar metric values grouped by stable diagnostic facets.
4091
92+ This function uses source_id-aware outlier detection, where IQR bounds are calculated
93+ using the mean value of each source_id rather than on all individual ensemble members.
94+ This ensures each model gets equal weight regardless of number of ensemble members.
95+ The calculated bounds are then applied to individual values for outlier detection.
96+
4197 Parameters
4298 ----------
4399 scalar_values
@@ -49,7 +105,7 @@ def detect_outliers_in_scalar_values(
49105
50106 Defaults to 3.0.
51107 min_n
52- The minimum number of data points required in a group to perform
108+ The minimum number of source_ids required in a group to perform
53109 IQR outlier detection. Defaults to 4.
54110 group_by
55111 A sequence of dimension names to group the `scalar_values` by before
@@ -72,20 +128,37 @@ def detect_outliers_in_scalar_values(
72128 group_by = [g for g in group_by if g in df .columns ]
73129
74130 for _ , group_values in df .groupby (list (group_by )):
75- print (group_values )
76131 # Identify non-finite values (NaN, inf) as outliers
77132 finite_flags = group_values .value .apply (
78133 lambda x : isinstance (x , int | float ) and not math .isinf (x ) and not math .isnan (x )
79134 )
80- # Apply IQR only if group has enough values
81- if len (group_values ) >= min_n :
82- iqr_flags = flag_outliers_iqr (group_values .value .to_list (), factor = factor )
135+ # Apply source_id-aware outlier detection if source_id exists
136+ if "source_id" in group_values .columns and len (group_values ) >= min_n :
137+ iqr_bounds = calculate_iqr_bounds_by_source_id (group_values , factor = factor , min_n = min_n )
138+
139+ if iqr_bounds is not None :
140+ lower_bound , upper_bound = iqr_bounds
141+ # Apply bounds to individual values (Reference values always non-outlier)
142+ source_id_flags = group_values .apply (
143+ lambda row : (row ["value" ] < lower_bound or row ["value" ] > upper_bound )
144+ if row ["source_id" ] != "Reference"
145+ else False ,
146+ axis = 1 ,
147+ )
148+ else :
149+ # Fallback if insufficient source_ids
150+ source_id_flags = [False ] * len (group_values ) # type: ignore
83151 else :
84- iqr_flags = [False ] * len (group_values )
85-
86- # Combine flags: item is outlier if iqr-flagged
87- for sv , is_outside_iqr , is_finite in zip (group_values .scalar_value , iqr_flags , finite_flags ):
88- is_outlier = is_outside_iqr or not is_finite
152+ # Fallback to original IQR method if no source_id or insufficient data
153+ if len (group_values ) >= min_n :
154+ iqr_flags = flag_outliers_iqr (group_values .value .to_list (), factor = factor )
155+ else :
156+ iqr_flags = [False ] * len (group_values )
157+ source_id_flags = iqr_flags # type: ignore
158+
159+ # Combine flags: item is outlier if flagged by source_id method OR non-finite
160+ for sv , is_source_outlier , is_finite in zip (group_values .scalar_value , source_id_flags , finite_flags ):
161+ is_outlier = is_source_outlier or not is_finite
89162 verification_status : Literal ["verified" , "unverified" ] = (
90163 "unverified" if is_outlier else "verified"
91164 )
0 commit comments