Skip to content

Commit 0910972

Browse files
committed
distinguish numerical exog from categorical features in predict
- Add has_numerical_exog_ flag to track true numerical exog - Only require exog_future when has_numerical_exog_ is True - Categorical-only models can predict without exog_future - Add has_numerical_exog to save/load state with backward compat - Fixes ValueError when predicting with categorical-only models
1 parent 5f891c3 commit 0910972

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

apdtflow/forecaster.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
self.exog_mean_: Optional[np.ndarray] = None
123123
self.exog_std_: Optional[np.ndarray] = None
124124
self.has_exog_ = False
125+
self.has_numerical_exog_ = False # Track if we have true numerical exog (vs just categorical)
125126

126127
# Categorical variables (NEW in v0.2.3)
127128
self.categorical_encoding = categorical_encoding
@@ -542,6 +543,7 @@ def fit(
542543
self._has_combined_exog = True
543544
# Update num features to reflect combined numerical + categorical
544545
self.has_exog_ = True
546+
self.has_numerical_exog_ = True # Track that we have true numerical exog
545547
self.num_exog_features_ = combined_exog.shape[1]
546548
if self.verbose:
547549
print(f"Combined {len(exog_cols)} numerical + {num_categorical_features} categorical = {self.num_exog_features_} total exogenous features")
@@ -550,6 +552,7 @@ def fit(
550552
self._combined_exog_data = categorical_encoded
551553
self._has_combined_exog = True
552554
self.has_exog_ = True
555+
self.has_numerical_exog_ = False # Only categorical, no numerical exog
553556
self.num_exog_features_ = categorical_encoded.shape[1]
554557
if self.verbose:
555558
print(f"Using {self.num_exog_features_} categorical features as exogenous variables")
@@ -564,6 +567,7 @@ def fit(
564567
f"Current model_type='{self.model_type}' does not support exog_cols."
565568
)
566569
self.has_exog_ = True
570+
self.has_numerical_exog_ = True # True numerical exogenous variables
567571
self.num_exog_features_ = len(exog_cols)
568572
if self.verbose:
569573
print(f"Using {self.num_exog_features_} exogenous features: {exog_cols}")
@@ -869,10 +873,11 @@ def predict(
869873
)
870874
steps = self.forecast_horizon
871875

872-
# Check exog requirements
873-
if self.has_exog_ and exog_future is None:
876+
# Check exog requirements - only require exog_future for numerical exog
877+
# Categorical features can be auto-generated
878+
if self.has_numerical_exog_ and exog_future is None:
874879
raise ValueError(
875-
"Model was trained with exogenous variables. "
880+
"Model was trained with numerical exogenous variables. "
876881
"Please provide exog_future parameter for prediction."
877882
)
878883

@@ -1500,6 +1505,7 @@ def save(self, filepath: str):
15001505
'exog_mean': self.exog_mean_,
15011506
'exog_std': self.exog_std_,
15021507
'has_exog': self.has_exog_,
1508+
'has_numerical_exog': self.has_numerical_exog_,
15031509

15041510
# Categorical variables state (NEW in v0.2.3)
15051511
'categorical_encoding': self.categorical_encoding,
@@ -1578,6 +1584,7 @@ def load(cls, filepath: str, device: Optional[str] = None):
15781584
model.exog_mean_ = state['exog_mean']
15791585
model.exog_std_ = state['exog_std']
15801586
model.has_exog_ = state['has_exog']
1587+
model.has_numerical_exog_ = state.get('has_numerical_exog', state['has_exog']) # Backward compat
15811588

15821589
# Restore conformal predictor from saved state
15831590
conformal_state = state.get('conformal_state', None)

0 commit comments

Comments
 (0)