1818
1919import hashlib
2020import logging
21+ from pathlib import Path
22+ from typing import Any
2123
2224import pandas as pd
2325
@@ -41,12 +43,13 @@ class InequalityHandler(ConstraintHandler):
4143
4244 _DATETIME_EPOCH = pd .Timestamp ("1970-01-01" ) # reference epoch for delta representation
4345
44- def __init__ (self , constraint : Inequality , table = None ):
46+ def __init__ (self , constraint : Inequality , table = None , workspace_dir : Path | None = None ):
4547 self .constraint = constraint
4648 self .table_name = constraint .table_name
4749 self .low_column = constraint .low_column
4850 self .high_column = constraint .high_column
4951 self ._delta_column = _generate_internal_column_name ("INEQ_DELTA" , [self .low_column , self .high_column ])
52+ self .workspace_dir = workspace_dir
5053
5154 # determine if this is a datetime constraint based on table encoding types
5255 self ._is_datetime = False
@@ -134,12 +137,63 @@ def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
134137 violations = both_valid_for_check & (high < low )
135138 df .loc [violations , self .high_column ] = low [violations ]
136139
140+ # clip to training data bounds
141+ if self .workspace_dir is not None :
142+ self ._clip_to_training_bounds (df )
143+
137144 # convert back to original dtype
138145 if pd .api .types .is_integer_dtype (high_dtype ):
139146 df [self .high_column ] = df [self .high_column ].astype (high_dtype )
140147
141148 return df .drop (columns = [self ._delta_column ])
142149
150+ def _extract_min_max_from_stats (self , col_stats : dict ) -> tuple [Any , Any ]:
151+ """extract min/max from column stats, handling all encoding types (same pattern as parse_min_max)."""
152+ # try bins/min5/max5 arrays first (for binned/digit/datetime encoding)
153+ values = col_stats .get ("bins" , []) + col_stats .get ("min5" , []) + col_stats .get ("max5" , [])
154+ if values :
155+ return min (values ), max (values )
156+ # fall back to direct min/max (for other encoding types or when arrays are empty)
157+ return col_stats .get ("min" ), col_stats .get ("max" )
158+
159+ def _clip_to_training_bounds (self , df : pd .DataFrame ) -> None :
160+ """clip high column values to min/max from training data stats."""
161+ from mostlyai .engine ._workspace import Workspace
162+
163+ workspace = Workspace (self .workspace_dir )
164+ tgt_stats = workspace .tgt_stats .read ()
165+ if not tgt_stats or self .high_column not in tgt_stats .get ("columns" , {}):
166+ return
167+
168+ col_stats = tgt_stats ["columns" ][self .high_column ]
169+ min_val , max_val = self ._extract_min_max_from_stats (col_stats )
170+ if min_val is None and max_val is None :
171+ return
172+
173+ high = df [self .high_column ]
174+ low = df [self .low_column ]
175+ if self ._is_datetime :
176+ if min_val is not None :
177+ min_val = pd .to_datetime (min_val )
178+ df .loc [high .notna () & (high < min_val ), self .high_column ] = min_val
179+ if max_val is not None :
180+ max_val = pd .to_datetime (max_val )
181+ df .loc [high .notna () & (high > max_val ), self .high_column ] = max_val
182+ else :
183+ if min_val is not None :
184+ min_val = float (min_val )
185+ df .loc [high .notna () & (high < min_val ), self .high_column ] = min_val
186+ if max_val is not None :
187+ max_val = float (max_val )
188+ df .loc [high .notna () & (high > max_val ), self .high_column ] = max_val
189+
190+ # ensure constraint is still satisfied after clipping
191+ # if clipping made high < low, set high = low
192+ high = df [self .high_column ]
193+ both_valid = low .notna () & high .notna ()
194+ violations = both_valid & (high < low )
195+ df .loc [violations , self .high_column ] = low [violations ]
196+
143197 def get_encoding_types (self ) -> dict [str , str ]:
144198 # use TABULAR_DATETIME for datetime constraints to preserve precision
145199 # use TABULAR_NUMERIC_AUTO for numeric constraints
0 commit comments