diff --git a/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py b/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py index dd957d1..fb4973e 100644 --- a/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py +++ b/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py @@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor: """ # Check input shape and expand if necessary if x.dim() == 3: # Case: (B, L, D) - no ensembles - batch_size, seq_len, input_size = x.shape + batch_size, seq_len, _ = x.shape # Shape: (B, L, ensemble_size, D) x = x.unsqueeze(2).expand(-1, -1, self.ensemble_size, -1) elif x.dim() == 4 and x.size(2) == self.ensemble_size: # Case: (B, L, ensemble_size, D) @@ -451,7 +451,7 @@ def forward(self, query, key, value, mask=None): If the ensemble size `E` does not match `self.ensemble_size`. """ - N, S, E, D = query.size() + N, S, E, _ = query.size() if E != self.ensemble_size: raise ValueError("Ensemble size mismatch.") diff --git a/deeptab/arch_utils/lstm_utils.py b/deeptab/arch_utils/lstm_utils.py index 72514eb..b04a0a7 100644 --- a/deeptab/arch_utils/lstm_utils.py +++ b/deeptab/arch_utils/lstm_utils.py @@ -128,7 +128,7 @@ def forward(self, x): """ if x.ndim != 3: raise ValueError("Input tensor must have 3 dimensions (batch, sequence_length, input_size)") - B, N, D = x.shape + B, N, _ = x.shape device = x.device # Initialize states dynamically based on input shape @@ -293,7 +293,7 @@ def forward(self, x): torch.Tensor Output tensor of shape (batch, sequence_length, input_size). """ - B, N, D = x.shape + B, N, _ = x.shape device = x.device # Initialize states dynamically based on input shape diff --git a/deeptab/arch_utils/mamba_utils/mamba_arch.py b/deeptab/arch_utils/mamba_utils/mamba_arch.py index ab9c4b5..826ece5 100644 --- a/deeptab/arch_utils/mamba_utils/mamba_arch.py +++ b/deeptab/arch_utils/mamba_utils/mamba_arch.py @@ -437,7 +437,7 @@ def forward(self, x): if self.bidirectional: xz_bwd = self.in_proj_bwd(x) - x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1) + x_bwd, _ = xz_bwd.chunk(2, dim=-1) x_bwd = x_bwd.transpose(1, 2) x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L] diff --git a/deeptab/arch_utils/transformer_utils.py b/deeptab/arch_utils/transformer_utils.py index d25d2f0..3d5eb12 100644 --- a/deeptab/arch_utils/transformer_utils.py +++ b/deeptab/arch_utils/transformer_utils.py @@ -329,11 +329,10 @@ def forward(self, x, mask: torch.Tensor = None): # type: ignore The output tensor of shape (N, S, E, D). """ if x.dim() == 3: # Case: (B, L, D) - no ensembles - batch_size, seq_len, input_size = x.shape # Shape: (B, L, ensemble_size, D) x = x.unsqueeze(2).expand(-1, -1, self.ensemble_size, -1) elif x.dim() == 4 and x.size(2) == self.ensemble_size: # Case: (B, L, ensemble_size, D) - batch_size, seq_len, ensemble_size, _ = x.shape + _, _, ensemble_size, _ = x.shape if ensemble_size != self.ensemble_size: raise ValueError(f"Input shape {x.shape} is invalid. Expected shape: (B, S, ensemble_size, N)") else: @@ -425,7 +424,7 @@ def forward(self, x): x: Input embeddings of shape (N, J, D), where N = batch size, J = number of features, D = embedding dimension. """ - _, n, d = x.shape + _, n, _ = x.shape for attn1, ff1, attn2, ff2 in self.layers: # type: ignore # Column-wise attention diff --git a/deeptab/base_models/autoint.py b/deeptab/base_models/autoint.py index a8a6b97..fa25996 100644 --- a/deeptab/base_models/autoint.py +++ b/deeptab/base_models/autoint.py @@ -162,9 +162,6 @@ def forward(self, *data): # Apply normalization before attention if prenormalization is enabled x_residual = layer["norm0"](x_residual) # type: ignore[index] - # Retrieve key-value compression layers - key_compression, value_compression = self._get_kv_compressions(layer) - # Multihead Attention x_residual, _ = layer["attention"](x_residual, x_residual, x_residual) # type: ignore[index] diff --git a/deeptab/base_models/ndtf.py b/deeptab/base_models/ndtf.py index 483b19e..e061482 100644 --- a/deeptab/base_models/ndtf.py +++ b/deeptab/base_models/ndtf.py @@ -123,9 +123,9 @@ def forward(self, *data) -> torch.Tensor: tree_input = x[:, : self.input_dimensions[idx]] preds.append(tree(tree_input, return_penalty=False)) - preds = torch.stack(preds, dim=1).squeeze(-1) - - return preds @ self.tree_weights + preds = torch.stack(preds, dim=1) # (batch, n_ensembles, output_dim) + # Weighted sum over ensemble dim: (batch, output_dim, n_ensembles) @ (n_ensembles, 1) + return (preds.transpose(1, 2) @ self.tree_weights).squeeze(-1) def penalty_forward(self, *data) -> torch.Tensor: """Forward pass of the NDTF model. @@ -158,5 +158,6 @@ def penalty_forward(self, *data) -> torch.Tensor: penalty += pen # Stack predictions and calculate mean across trees - preds = torch.stack(preds, dim=1).squeeze(-1) - return preds @ self.tree_weights, self.hparams.penalty_factor * penalty # type: ignore + preds = torch.stack(preds, dim=1) # (batch, n_ensembles, output_dim) + # Weighted sum over ensemble dim: (batch, output_dim, n_ensembles) @ (n_ensembles, 1) + return (preds.transpose(1, 2) @ self.tree_weights).squeeze(-1), self.hparams.penalty_factor * penalty # type: ignore diff --git a/deeptab/base_models/tabr.py b/deeptab/base_models/tabr.py index 4f327ca..187ff06 100644 --- a/deeptab/base_models/tabr.py +++ b/deeptab/base_models/tabr.py @@ -21,10 +21,11 @@ def __init__( self, feature_information: tuple, num_classes=1, + lss: bool = False, config: DefaultTabRConfig = DefaultTabRConfig(), # noqa: B008 **kwargs, ): - super().__init__(config=config, **kwargs) + super().__init__(config=config, lss=lss, **kwargs) self.save_hyperparameters(ignore=["feature_information"]) # lazy import @@ -94,7 +95,7 @@ def make_block(prenorm: bool) -> nn.Sequential: delu = TabR.delu self.label_encoder = ( nn.Linear(1, d_main) - if num_classes == 1 + if num_classes == 1 or lss else nn.Sequential( nn.Embedding(num_classes, d_main), # gives depreciation warning @@ -278,10 +279,10 @@ def train_with_candidates(self, *data, targets, candidate_x, candidate_y): probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - if self.hparams.num_classes > 1: # for classification + if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) - else: # for regression - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + else: # for regression or LSS + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] @@ -324,7 +325,7 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): candidate_x, candidate_k = self._encode(candidate_x) x, k = self._encode(x) # encoded x and k - batch_size, d_main = k.shape + _, d_main = k.shape device = k.device context_size = self.context_size @@ -338,9 +339,8 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): # Updating the index is much faster than creating a new one. self.search_index.reset() self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code] - distances: Tensor context_idx: Tensor - distances, context_idx = self.search_index.search( # type: ignore[code] + _, context_idx = self.search_index.search( # type: ignore[code] k.to(torch.float32), context_size ) @@ -353,10 +353,10 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - if self.hparams.num_classes > 1: # for classification + if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] @@ -398,7 +398,7 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): candidate_x, candidate_k = self._encode(candidate_x) x, k = self._encode(x) # encoded x and k - batch_size, d_main = k.shape + _, d_main = k.shape device = k.device context_size = self.context_size @@ -412,9 +412,8 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): # Updating the index is much faster than creating a new one. self.search_index.reset() self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code] - distances: Tensor context_idx: Tensor - distances, context_idx = self.search_index.search( # type: ignore[code] + _, context_idx = self.search_index.search( # type: ignore[code] k.to(torch.float32), context_size ) @@ -427,10 +426,10 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - if self.hparams.num_classes > 1: # for classification + if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] diff --git a/deeptab/base_models/tabtransformer.py b/deeptab/base_models/tabtransformer.py index 53e7ac2..9446904 100644 --- a/deeptab/base_models/tabtransformer.py +++ b/deeptab/base_models/tabtransformer.py @@ -96,8 +96,11 @@ def __init__( mlp_input_dim = 0 for feature_name, info in num_feature_info.items(): mlp_input_dim += info["dimension"] + num_input_dim = mlp_input_dim # save before adding d_model mlp_input_dim += self.hparams.d_model + self.num_norm = nn.LayerNorm(num_input_dim) + self.tabular_head = MLPhead( input_dim=mlp_input_dim, config=config, @@ -125,13 +128,13 @@ def forward(self, *data): cat_embeddings = self.embedding_layer(*(None, cat_features, emb_features)) num_features = torch.cat(num_features, dim=1) - num_embeddings = self.norm_f(num_features) # type: ignore + num_features = self.num_norm(num_features) x = self.encoder(cat_embeddings) x = self.pool_sequence(x) - x = torch.cat((x, num_embeddings), axis=1) # type: ignore + x = torch.cat((x, num_features), axis=1) # type: ignore preds = self.tabular_head(x) return preds diff --git a/deeptab/base_models/utils/basemodel.py b/deeptab/base_models/utils/basemodel.py index 6837b73..d6d7e37 100644 --- a/deeptab/base_models/utils/basemodel.py +++ b/deeptab/base_models/utils/basemodel.py @@ -193,7 +193,7 @@ def pool_sequence(self, out): return out[:, 0, :] elif self.hparams.pooling_method == "learned_flatten": # Flatten sequence and apply a learned linear layer - batch_size, seq_len, hidden_size = out.shape + batch_size, _, _ = out.shape # Shape: (batch_size, seq_len * hidden_size) out = out.reshape(batch_size, -1) # Shape: (batch_size, hidden_size) diff --git a/deeptab/base_models/utils/lightning_wrapper.py b/deeptab/base_models/utils/lightning_wrapper.py index ae6125d..9a9c390 100644 --- a/deeptab/base_models/utils/lightning_wrapper.py +++ b/deeptab/base_models/utils/lightning_wrapper.py @@ -77,7 +77,7 @@ def __init__( else: self.loss_fct = nn.MSELoss() - self.save_hyperparameters(ignore=["model_class", "loss_fn"]) + self.save_hyperparameters(ignore=["model_class", "loss_fn", "family"]) self.lr = self.hparams.get("lr", config.lr) self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) @@ -93,6 +93,7 @@ def __init__( config=config, feature_information=feature_information, num_classes=output_dim, + lss=lss, **kwargs, ) @@ -176,7 +177,7 @@ def compute_loss(self, predictions, y_true): if getattr(self.estimator, "returns_ensemble", False): # Ensemble case if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3: # Classification case with ensemble: predictions (N, E, k), y_true (N,) - N, E, k = predictions.shape + _, E, _ = predictions.shape loss = 0.0 for ensemble_member in range(E): loss += self.loss_fct( @@ -595,7 +596,7 @@ def contrastive_loss(self, embeddings, knn_indices, temperature=0.1): Tensor Contrastive loss value. """ - N, S, D = embeddings.shape # Batch size, sequence length, embedding dim + _, S, D = embeddings.shape # Batch size, sequence length, embedding dim k_neighbors = knn_indices.shape[1] # Number of neighbors # Normalize embeddings diff --git a/deeptab/configs/__init__.py b/deeptab/configs/__init__.py index 63cf376..287e358 100644 --- a/deeptab/configs/__init__.py +++ b/deeptab/configs/__init__.py @@ -12,6 +12,7 @@ from .resnet_config import DefaultResNetConfig from .saint_config import DefaultSAINTConfig from .tabm_config import DefaultTabMConfig +from .tabr_config import DefaultTabRConfig from .tabtransformer_config import DefaultTabTransformerConfig from .tabularnn_config import DefaultTabulaRNNConfig from .tangos_config import DefaultTangosConfig @@ -32,6 +33,7 @@ "DefaultResNetConfig", "DefaultSAINTConfig", "DefaultTabMConfig", + "DefaultTabRConfig", "DefaultTabTransformerConfig", "DefaultTabulaRNNConfig", "DefaultTangosConfig", diff --git a/deeptab/models/__init__.py b/deeptab/models/__init__.py index ebffddc..48838d2 100644 --- a/deeptab/models/__init__.py +++ b/deeptab/models/__init__.py @@ -1,3 +1,6 @@ +import importlib +import warnings + from .autoint import AutoIntClassifier, AutoIntLSS, AutoIntRegressor from .enode import ENODELSS, ENODEClassifier, ENODERegressor from .fttransformer import ( @@ -13,20 +16,18 @@ ) from .mambular import MambularClassifier, MambularLSS, MambularRegressor from .mlp import MLPLSS, MLPClassifier, MLPRegressor -from .modern_nca import ModernNCAClassifier, ModernNCALSS, ModernNCARegressor from .ndtf import NDTFLSS, NDTFClassifier, NDTFRegressor from .node import NODELSS, NODEClassifier, NODERegressor from .resnet import ResNetClassifier, ResNetLSS, ResNetRegressor from .saint import SAINTLSS, SAINTClassifier, SAINTRegressor from .tabm import TabMClassifier, TabMLSS, TabMRegressor +from .tabr import TabRClassifier, TabRLSS, TabRRegressor from .tabtransformer import ( TabTransformerClassifier, TabTransformerLSS, TabTransformerRegressor, ) from .tabularnn import TabulaRNNClassifier, TabulaRNNLSS, TabulaRNNRegressor -from .tangos import TangosClassifier, TangosLSS, TangosRegressor -from .trompt import TromptClassifier, TromptLSS, TromptRegressor from .utils.sklearn_base_classifier import SklearnBaseClassifier from .utils.sklearn_base_lss import SklearnBaseLSS from .utils.sklearn_base_regressor import SklearnBaseRegressor @@ -56,9 +57,6 @@ "MambularClassifier", "MambularLSS", "MambularRegressor", - "ModernNCAClassifier", - "ModernNCALSS", - "ModernNCARegressor", "NDTFClassifier", "NDTFRegressor", "NODEClassifier", @@ -74,16 +72,42 @@ "TabMClassifier", "TabMLSS", "TabMRegressor", + "TabRClassifier", + "TabRLSS", + "TabRRegressor", "TabTransformerClassifier", "TabTransformerLSS", "TabTransformerRegressor", "TabulaRNNClassifier", "TabulaRNNLSS", "TabulaRNNRegressor", - "TangosClassifier", - "TangosLSS", - "TangosRegressor", - "TromptClassifier", - "TromptLSS", - "TromptRegressor", ] + +# --------------------------------------------------------------------------- +# Backwards-compatibility shim for experimental models +# --------------------------------------------------------------------------- + +_EXPERIMENTAL_COMPAT: dict[str, str] = { + "ModernNCAClassifier": "deeptab.models.experimental", + "ModernNCALSS": "deeptab.models.experimental", + "ModernNCARegressor": "deeptab.models.experimental", + "TangosClassifier": "deeptab.models.experimental", + "TangosLSS": "deeptab.models.experimental", + "TangosRegressor": "deeptab.models.experimental", + "TromptClassifier": "deeptab.models.experimental", + "TromptLSS": "deeptab.models.experimental", + "TromptRegressor": "deeptab.models.experimental", +} + + +def __getattr__(name: str): + if name in _EXPERIMENTAL_COMPAT: + new_path = _EXPERIMENTAL_COMPAT[name] + warnings.warn( + f"{name!r} has moved to '{new_path}'. Update your import: from {new_path} import {name}", + DeprecationWarning, + stacklevel=2, + ) + mod = importlib.import_module(new_path) + return getattr(mod, name) + raise AttributeError(f"module 'deeptab.models' has no attribute {name!r}") diff --git a/deeptab/models/_registry.py b/deeptab/models/_registry.py new file mode 100644 index 0000000..df2703c --- /dev/null +++ b/deeptab/models/_registry.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import Literal + +ModelStatus = Literal["stable", "experimental"] + + +@dataclass(frozen=True) +class ModelInfo: + name: str + status: ModelStatus + import_path: str + + +MODEL_REGISTRY: dict[str, ModelInfo] = { + "Mambular": ModelInfo("Mambular", "stable", "deeptab.models"), + "TabM": ModelInfo("TabM", "stable", "deeptab.models"), + "NODE": ModelInfo("NODE", "stable", "deeptab.models"), + "ENODE": ModelInfo("ENODE", "stable", "deeptab.models"), + "FTTransformer": ModelInfo("FTTransformer", "stable", "deeptab.models"), + "MLP": ModelInfo("MLP", "stable", "deeptab.models"), + "ResNet": ModelInfo("ResNet", "stable", "deeptab.models"), + "TabTransformer": ModelInfo("TabTransformer", "stable", "deeptab.models"), + "MambaTab": ModelInfo("MambaTab", "stable", "deeptab.models"), + "TabulaRNN": ModelInfo("TabulaRNN", "stable", "deeptab.models"), + "MambAttention": ModelInfo("MambAttention", "stable", "deeptab.models"), + "NDTF": ModelInfo("NDTF", "stable", "deeptab.models"), + "SAINT": ModelInfo("SAINT", "stable", "deeptab.models"), + "AutoInt": ModelInfo("AutoInt", "stable", "deeptab.models"), + "TabR": ModelInfo("TabR", "stable", "deeptab.models"), + "ModernNCA": ModelInfo("ModernNCA", "experimental", "deeptab.models.experimental"), + "Tangos": ModelInfo("Tangos", "experimental", "deeptab.models.experimental"), + "Trompt": ModelInfo("Trompt", "experimental", "deeptab.models.experimental"), +} diff --git a/deeptab/models/experimental/__init__.py b/deeptab/models/experimental/__init__.py new file mode 100644 index 0000000..59ad46b --- /dev/null +++ b/deeptab/models/experimental/__init__.py @@ -0,0 +1,23 @@ +""" +Experimental models — subject to change without notice. + +Import these explicitly to signal that you accept the instability: + + from deeptab.models.experimental import ModernNCA, Tangos, Trompt +""" + +from .modern_nca import ModernNCAClassifier, ModernNCALSS, ModernNCARegressor +from .tangos import TangosClassifier, TangosLSS, TangosRegressor +from .trompt import TromptClassifier, TromptLSS, TromptRegressor + +__all__ = [ + "ModernNCAClassifier", + "ModernNCALSS", + "ModernNCARegressor", + "TangosClassifier", + "TangosLSS", + "TangosRegressor", + "TromptClassifier", + "TromptLSS", + "TromptRegressor", +] diff --git a/deeptab/models/modern_nca.py b/deeptab/models/experimental/modern_nca.py similarity index 78% rename from deeptab/models/modern_nca.py rename to deeptab/models/experimental/modern_nca.py index 4b78479..6530e18 100644 --- a/deeptab/models/modern_nca.py +++ b/deeptab/models/experimental/modern_nca.py @@ -1,9 +1,9 @@ -from ..base_models.modern_nca import ModernNCA -from ..configs.modernnca_config import DefaultModernNCAConfig -from ..utils.docstring_generator import generate_docstring -from .utils.sklearn_base_classifier import SklearnBaseClassifier -from .utils.sklearn_base_lss import SklearnBaseLSS -from .utils.sklearn_base_regressor import SklearnBaseRegressor +from ...base_models.modern_nca import ModernNCA +from ...configs.modernnca_config import DefaultModernNCAConfig +from ...utils.docstring_generator import generate_docstring +from ..utils.sklearn_base_classifier import SklearnBaseClassifier +from ..utils.sklearn_base_lss import SklearnBaseLSS +from ..utils.sklearn_base_regressor import SklearnBaseRegressor class ModernNCARegressor(SklearnBaseRegressor): @@ -14,7 +14,7 @@ class ModernNCARegressor(SklearnBaseRegressor): with the default ModernNCA configuration. """, examples=""" - >>> from deeptab.models import ModernNCARegressor + >>> from deeptab.models.experimental import ModernNCARegressor >>> model = ModernNCARegressor(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -34,7 +34,7 @@ class ModernNCAClassifier(SklearnBaseClassifier): with the default ModernNCA configuration. """, examples=""" - >>> from deeptab.models import ModernNCAClassifier + >>> from deeptab.models.experimental import ModernNCAClassifier >>> model = ModernNCAClassifier(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -54,7 +54,7 @@ class ModernNCALSS(SklearnBaseLSS): with the default ModernNCA configuration. """, examples=""" - >>> from deeptab.models import ModernNCALSS + >>> from deeptab.models.experimental import ModernNCALSS >>> model = ModernNCALSS(d_model=64, n_layers=8) >>> model.fit(X_train, y_train, family='normal') >>> preds = model.predict(X_test) diff --git a/deeptab/models/tangos.py b/deeptab/models/experimental/tangos.py similarity index 78% rename from deeptab/models/tangos.py rename to deeptab/models/experimental/tangos.py index abbd437..9502791 100644 --- a/deeptab/models/tangos.py +++ b/deeptab/models/experimental/tangos.py @@ -1,9 +1,9 @@ -from ..base_models.tangos import Tangos -from ..configs.tangos_config import DefaultTangosConfig -from ..utils.docstring_generator import generate_docstring -from .utils.sklearn_base_classifier import SklearnBaseClassifier -from .utils.sklearn_base_lss import SklearnBaseLSS -from .utils.sklearn_base_regressor import SklearnBaseRegressor +from ...base_models.tangos import Tangos +from ...configs.tangos_config import DefaultTangosConfig +from ...utils.docstring_generator import generate_docstring +from ..utils.sklearn_base_classifier import SklearnBaseClassifier +from ..utils.sklearn_base_lss import SklearnBaseLSS +from ..utils.sklearn_base_regressor import SklearnBaseRegressor class TangosRegressor(SklearnBaseRegressor): @@ -14,7 +14,7 @@ class TangosRegressor(SklearnBaseRegressor): with the default Tangos configuration. """, examples=""" - >>> from deeptab.models import TangosRegressor + >>> from deeptab.models.experimental import TangosRegressor >>> model = TangosRegressor(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -34,7 +34,7 @@ class TangosClassifier(SklearnBaseClassifier): with the default Tangos configuration. """, examples=""" - >>> from deeptab.models import TangosClassifier + >>> from deeptab.models.experimental import TangosClassifier >>> model = TangosClassifier(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -54,7 +54,7 @@ class TangosLSS(SklearnBaseLSS): with the default Tangos configuration. """, examples=""" - >>> from deeptab.models import TangosLSS + >>> from deeptab.models.experimental import TangosLSS >>> model = TangosLSS(d_model=64, n_layers=8) >>> model.fit(X_train, y_train, family='normal') >>> preds = model.predict(X_test) diff --git a/deeptab/models/trompt.py b/deeptab/models/experimental/trompt.py similarity index 77% rename from deeptab/models/trompt.py rename to deeptab/models/experimental/trompt.py index d827a99..3109cae 100644 --- a/deeptab/models/trompt.py +++ b/deeptab/models/experimental/trompt.py @@ -1,9 +1,9 @@ -from ..base_models.trompt import Trompt -from ..configs.trompt_config import DefaultTromptConfig -from ..utils.docstring_generator import generate_docstring -from .utils.sklearn_base_classifier import SklearnBaseClassifier -from .utils.sklearn_base_lss import SklearnBaseLSS -from .utils.sklearn_base_regressor import SklearnBaseRegressor +from ...base_models.trompt import Trompt +from ...configs.trompt_config import DefaultTromptConfig +from ...utils.docstring_generator import generate_docstring +from ..utils.sklearn_base_classifier import SklearnBaseClassifier +from ..utils.sklearn_base_lss import SklearnBaseLSS +from ..utils.sklearn_base_regressor import SklearnBaseRegressor class TromptRegressor(SklearnBaseRegressor): @@ -15,7 +15,7 @@ class and uses the Trompt model with the default Trompt configuration. """, examples=""" - >>> from deeptab.models import TromptRegressor + >>> from deeptab.models.experimental import TromptRegressor >>> model = TromptRegressor(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -33,7 +33,7 @@ class TromptClassifier(SklearnBaseClassifier): """Trompt Classifier. This class extends the SklearnBaseClassifier class and uses the Trompt model with the default Trompt configuration.""", examples=""" - >>> from deeptab.models import TromptClassifier + >>> from deeptab.models.experimental import TromptClassifier >>> model = TromptClassifier(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -52,7 +52,7 @@ class TromptLSS(SklearnBaseLSS): This class extends the SklearnBaseLSS class and uses the Trompt model with the default Trompt configuration.""", examples=""" - >>> from deeptab.models import TromptLSS + >>> from deeptab.models.experimental import TromptLSS >>> model = TromptLSS(d_model=64, n_layers=8) >>> model.fit(X_train, y_train, family="normal") >>> preds = model.predict(X_test) diff --git a/deeptab/models/mambatab.py b/deeptab/models/mambatab.py index d6f98ba..7885a08 100644 --- a/deeptab/models/mambatab.py +++ b/deeptab/models/mambatab.py @@ -13,7 +13,13 @@ class MambaTabRegressor(SklearnBaseRegressor): MambaTab regressor. This class extends the SklearnBaseRegressor class and uses the MambaTab model with the default MambaTab configuration. """, - examples="", + examples=""" + >>> from deeptab.models import MambaTabRegressor + >>> model = MambaTabRegressor(d_model=64, n_layers=2) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -27,7 +33,13 @@ class MambaTabClassifier(SklearnBaseClassifier): MambaTab classifier. This class extends the SklearnBaseClassifier class and uses the MambaTab model with the default MambaTab configuration. """, - examples="", + examples=""" + >>> from deeptab.models import MambaTabClassifier + >>> model = MambaTabClassifier(d_model=64, n_layers=2) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -41,7 +53,13 @@ class MambaTabLSS(SklearnBaseLSS): MambaTab LSS for distributional regression. This class extends the SklearnBaseLSS class and uses the MambaTab model with the default MambaTab configuration. """, - examples="", + examples=""" + >>> from deeptab.models import MambaTabLSS + >>> model = MambaTabLSS(d_model=64, n_layers=2) + >>> model.fit(X_train, y_train, family='normal') + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): diff --git a/deeptab/models/ndtf.py b/deeptab/models/ndtf.py index dc02d71..ca7b552 100644 --- a/deeptab/models/ndtf.py +++ b/deeptab/models/ndtf.py @@ -13,7 +13,13 @@ class NDTFRegressor(SklearnBaseRegressor): Neural Decision Forest regressor. This class extends the SklearnBaseRegressor class and uses the NDTF model with the default NDTF configuration. """, - examples="", + examples=""" + >>> from deeptab.models import NDTFRegressor + >>> model = NDTFRegressor(n_ensembles=12, max_depth=8) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -24,10 +30,16 @@ class NDTFClassifier(SklearnBaseClassifier): __doc__ = generate_docstring( DefaultNDTFConfig, model_description=""" - Neural Decision Forest classifier. This class extends the SklearnBasClassifier class and uses the NDTF model + Neural Decision Forest classifier. This class extends the SklearnBaseClassifier class and uses the NDTF model with the default NDTF configuration. """, - examples="", + examples=""" + >>> from deeptab.models import NDTFClassifier + >>> model = NDTFClassifier(n_ensembles=12, max_depth=8) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -38,10 +50,16 @@ class NDTFLSS(SklearnBaseLSS): __doc__ = generate_docstring( DefaultNDTFConfig, model_description=""" - Neural Decision Forest for distributional regressor. This class extends the SklearnBaseLSS class and uses the NDTF model + Neural Decision Forest for distributional regression. This class extends the SklearnBaseLSS class and uses the NDTF model with the default NDTF configuration. """, - examples="", + examples=""" + >>> from deeptab.models import NDTFLSS + >>> model = NDTFLSS(n_ensembles=12, max_depth=8) + >>> model.fit(X_train, y_train, family='normal') + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): diff --git a/deeptab/models/tabularnn.py b/deeptab/models/tabularnn.py index a14910a..8febf9e 100644 --- a/deeptab/models/tabularnn.py +++ b/deeptab/models/tabularnn.py @@ -49,85 +49,21 @@ def __init__(self, **kwargs): class TabulaRNNLSS(SklearnBaseLSS): - """RNN LSS. This class extends the SklearnBaseLSS class and uses the TabulaRNN model with the default TabulaRNN - configuration. - - The accepted arguments to the TabulaRNNLSS class include both the attributes in the DefaultTabulaRNNConfig dataclass - and the parameters for the Preprocessor class. - - Parameters - ---------- - lr : float, default=1e-04 - Learning rate for the optimizer. - model_type : str, default="RNN" - type of model, one of "RNN", "LSTM", "GRU" - family : str, default=None - Distributional family to be used for the model. - lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. - weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. - lr_factor : float, default=0.1 - Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the transformer. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() - Activation function for the transformer. - embedding_activation : callable, default=nn.Identity() - Activation function for numerical embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. - head_dropout : float, default=0.5 - Dropout rate for the head layers. - head_skip_layers : bool, default=False - Whether to skip layers in the head. - head_activation : callable, default=nn.SELU() - Activation function for the head layers. - head_use_batch_norm : bool, default=False - Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="cls" - Pooling method to be used ('cls', 'avg', etc.). - norm_first : bool, default=False - Whether to apply normalization before other operations in each transformer block. - bias : bool, default=True - Whether to use bias in the linear layers. - rnn_activation : callable, default=nn.SELU() - Activation function for the transformer layers. - bidirectional : bool, default=False. - Whether to process data bidirectionally - cat_encoding : str, default="int" - Encoding method for categorical features. - n_bins : int, default=50 - The number of bins to use for numerical feature binning. This parameter is relevant - only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. - numerical_preprocessing : str, default="ple" - The preprocessing strategy for numerical features. Valid options are - 'binning', 'one_hot', 'standardization', and 'normalization'. - use_decision_tree_bins : bool, default=False - If True, uses decision tree regression/classification to determine - optimal bin edges for numerical feature binning. This parameter is - relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. - binning_strategy : str, default="uniform" - Defines the strategy for binning numerical features. Options include 'uniform', - 'quantile', or other sklearn-compatible strategies. - cat_cutoff : float or int, default=0.03 - Indicates the cutoff after which integer values are treated as categorical. - If float, it's treated as a percentage. If int, it's the maximum number of - unique values for a column to be considered categorical. - treat_all_integers_as_numerical : bool, default=False - If True, all integer columns will be treated as numerical, regardless - of their unique value count or proportion. - degree : int, default=3 - The degree of the polynomial features to be used in preprocessing. - knots : int, default=12 - The number of knots to be used in spline transformations. - """ + __doc__ = generate_docstring( + DefaultTabulaRNNConfig, + model_description=""" + TabulaRNN for distributional regression. This class extends the SklearnBaseLSS + class and uses the TabulaRNN model with the default TabulaRNN configuration. + Supports RNN, LSTM, GRU, mLSTM, and sLSTM architectures. + """, + examples=""" + >>> from deeptab.models import TabulaRNNLSS + >>> model = TabulaRNNLSS(model_type='LSTM', d_model=128, n_layers=4) + >>> model.fit(X_train, y_train, family='normal') + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, + ) def __init__(self, **kwargs): super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs) diff --git a/deeptab/models/utils/sklearn_base_lss.py b/deeptab/models/utils/sklearn_base_lss.py index de6e8e9..59b08f9 100644 --- a/deeptab/models/utils/sklearn_base_lss.py +++ b/deeptab/models/utils/sklearn_base_lss.py @@ -38,6 +38,20 @@ StudentTDistribution, ) +DISTRIBUTION_CLASSES = { + "normal": NormalDistribution, + "poisson": PoissonDistribution, + "gamma": GammaDistribution, + "beta": BetaDistribution, + "dirichlet": DirichletDistribution, + "studentt": StudentTDistribution, + "negativebinom": NegativeBinomialDistribution, + "inversegamma": InverseGammaDistribution, + "categorical": CategoricalDistribution, + "quantile": Quantile, + "johnsonsu": JohnsonSuDistribution, +} + class SklearnBaseLSS(BaseEstimator): def __init__(self, model, config, **kwargs): @@ -106,8 +120,10 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} # type: ignore[attr-defined] - params.update(preprocessor_params) + get_params_fn = getattr(self.preprocessor, "get_params", None) + if get_params_fn is not None: + preprocessor_params = {"prepro__" + key: value for key, value in get_params_fn().items()} + params.update(preprocessor_params) return params @@ -381,6 +397,7 @@ def fit( if family in distribution_classes: self.family = distribution_classes[family](**distributional_kwargs) + self.family_name = family else: raise ValueError(f"Unsupported family: {family}") @@ -622,6 +639,138 @@ def encode(self, X, batch_size=64): return encoded_outputs + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Save the fitted model to *path*. + + Parameters + ---------- + path : str + Destination file path (e.g. ``"model.pt"``). + + Raises + ------ + ValueError + If the model has not been fitted yet. + """ + if not getattr(self, "is_fitted_", False): + raise ValueError("Model must be fitted before saving.") + if self.task_model is None: + raise RuntimeError("task_model is unexpectedly None after fitting.") + bundle = { + "_class": type(self), + "config": self.config, + "config_kwargs": self.config_kwargs, + "preprocessor": self.preprocessor, + "feature_info": { + "num": self.data_module.num_feature_info, + "cat": self.data_module.cat_feature_info, + "emb": self.data_module.embedding_feature_info, + }, + "batch_size": self.data_module.batch_size, + "regression": self.data_module.regression, + "model_class": type(self.estimator), + "num_classes": self.task_model.num_classes, + "lss": True, + "family": self.family_name, + "optimizer_type": self.optimizer_type, + "optimizer_kwargs": self.optimizer_kwargs, + "lr": self.task_model.lr, + "lr_patience": self.task_model.lr_patience, + "lr_factor": self.task_model.lr_factor, + "weight_decay": self.task_model.weight_decay, + "task_model_state_dict": self.task_model.state_dict(), + } + torch.save(bundle, path) + + @classmethod + def load(cls, path: str): + """Load and return a fitted model from *path*. + + Parameters + ---------- + path : str + Path to a file previously written by :meth:`save`. + + Returns + ------- + estimator + A fully reconstructed, ready-to-predict estimator. + """ + bundle = torch.load(path, weights_only=False) + + obj = bundle["_class"].__new__(bundle["_class"]) + obj.config = bundle["config"] + obj.config_kwargs = bundle["config_kwargs"] + obj.preprocessor = bundle["preprocessor"] + obj.optimizer_type = bundle["optimizer_type"] + obj.optimizer_kwargs = bundle["optimizer_kwargs"] + obj.built = True + obj.is_fitted_ = True + obj.family = DISTRIBUTION_CLASSES[bundle["family"]]() + obj.family_name = bundle["family"] + obj.preprocessor_arg_names = [ + "n_bins", + "feature_preprocessing", + "numerical_preprocessing", + "categorical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "degree", + "scaling_strategy", + "n_knots", + "use_decision_tree_knots", + "knots_strategy", + "spline_implementation", + ] + + obj.data_module = MambularDataModule( + preprocessor=bundle["preprocessor"], + batch_size=bundle["batch_size"], + shuffle=False, + regression=bundle["regression"], + ) + obj.data_module.num_feature_info = bundle["feature_info"]["num"] + obj.data_module.cat_feature_info = bundle["feature_info"]["cat"] + obj.data_module.embedding_feature_info = bundle["feature_info"]["emb"] + + obj.task_model = TaskModel( + model_class=bundle["model_class"], + config=bundle["config"], + feature_information=( + bundle["feature_info"]["num"], + bundle["feature_info"]["cat"], + bundle["feature_info"]["emb"], + ), + num_classes=bundle["num_classes"], + lss=bundle["lss"], + family=obj.family, + optimizer_type=bundle["optimizer_type"], + optimizer_args=bundle["optimizer_kwargs"], + lr=bundle["lr"], + lr_patience=bundle["lr_patience"], + lr_factor=bundle["lr_factor"], + weight_decay=bundle["weight_decay"], + ) + obj.task_model.load_state_dict(bundle["task_model_state_dict"]) + obj.task_model.eval() + obj.estimator = obj.task_model.estimator + + obj.trainer = pl.Trainer( + max_epochs=1, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + + return obj + def optimize_hparams( self, X, diff --git a/deeptab/models/utils/sklearn_parent.py b/deeptab/models/utils/sklearn_parent.py index 709c723..f738cac 100644 --- a/deeptab/models/utils/sklearn_parent.py +++ b/deeptab/models/utils/sklearn_parent.py @@ -64,12 +64,12 @@ def get_params(self, deep=True): params.update(self.config_kwargs) params.update(self.preprocessor_kwargs) if deep: - preprocessor_params = { - key: value - for key, value in self.preprocessor.get_params().items() # type: ignore[attr-defined] - if key in self.preprocessor_arg_names - } - params.update(preprocessor_params) + get_params_fn = getattr(self.preprocessor, "get_params", None) + if get_params_fn is not None: + preprocessor_params = { + key: value for key, value in get_params_fn().items() if key in self.preprocessor_arg_names + } + params.update(preprocessor_params) return params def set_params(self, **parameters): @@ -478,6 +478,144 @@ def _pretrain( pool_sequence=pool_sequence, ) + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Save the fitted model to *path*. + + The bundle written by this method can be restored with + :meth:`load`. It contains all state required for inference: + the config, the fitted preprocessor, feature metadata, and + the neural-network weights. + + Parameters + ---------- + path : str + Destination file path (e.g. ``"model.pt"``). + + Raises + ------ + ValueError + If the model has not been fitted yet. + """ + if not getattr(self, "is_fitted_", False): + raise ValueError("Model must be fitted before saving.") + if self.task_model is None: + raise RuntimeError("task_model is unexpectedly None after fitting.") + bundle = { + "_class": type(self), + "config": self.config, + "config_kwargs": self.config_kwargs, + "preprocessor_kwargs": getattr(self, "preprocessor_kwargs", {}), + "preprocessor": self.preprocessor, + "feature_info": { + "num": self.data_module.num_feature_info, + "cat": self.data_module.cat_feature_info, + "emb": self.data_module.embedding_feature_info, + }, + "batch_size": self.data_module.batch_size, + "regression": self.data_module.regression, + "model_class": type(self.estimator), + "num_classes": self.task_model.num_classes, + "lss": False, + "family": None, + "optimizer_type": self.optimizer_type, + "optimizer_kwargs": self.optimizer_kwargs, + "lr": self.task_model.lr, + "lr_patience": self.task_model.lr_patience, + "lr_factor": self.task_model.lr_factor, + "weight_decay": self.task_model.weight_decay, + "task_model_state_dict": self.task_model.state_dict(), + } + torch.save(bundle, path) + + @classmethod + def load(cls, path: str): + """Load and return a fitted model from *path*. + + Parameters + ---------- + path : str + Path to a file previously written by :meth:`save`. + + Returns + ------- + estimator + A fully reconstructed, ready-to-predict estimator of the + same type that was saved. + """ + bundle = torch.load(path, weights_only=False) + + obj = bundle["_class"].__new__(bundle["_class"]) + obj.config = bundle["config"] + obj.config_kwargs = bundle["config_kwargs"] + obj.preprocessor_kwargs = bundle.get("preprocessor_kwargs", {}) + obj.preprocessor = bundle["preprocessor"] + obj.optimizer_type = bundle["optimizer_type"] + obj.optimizer_kwargs = bundle["optimizer_kwargs"] + obj.built = True + obj.is_fitted_ = True + obj.preprocessor_arg_names = [ + "n_bins", + "feature_preprocessing", + "numerical_preprocessing", + "categorical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "degree", + "scaling_strategy", + "n_knots", + "use_decision_tree_knots", + "knots_strategy", + "spline_implementation", + ] + + obj.data_module = MambularDataModule( + preprocessor=bundle["preprocessor"], + batch_size=bundle["batch_size"], + shuffle=False, + regression=bundle["regression"], + ) + obj.data_module.num_feature_info = bundle["feature_info"]["num"] + obj.data_module.cat_feature_info = bundle["feature_info"]["cat"] + obj.data_module.embedding_feature_info = bundle["feature_info"]["emb"] + + obj.task_model = TaskModel( + model_class=bundle["model_class"], + config=bundle["config"], + feature_information=( + bundle["feature_info"]["num"], + bundle["feature_info"]["cat"], + bundle["feature_info"]["emb"], + ), + num_classes=bundle["num_classes"], + lss=bundle["lss"], + family=bundle["family"], + optimizer_type=bundle["optimizer_type"], + optimizer_args=bundle["optimizer_kwargs"], + lr=bundle["lr"], + lr_patience=bundle["lr_patience"], + lr_factor=bundle["lr_factor"], + weight_decay=bundle["weight_decay"], + ) + obj.task_model.load_state_dict(bundle["task_model_state_dict"]) + obj.task_model.eval() + obj.estimator = obj.task_model.estimator + + obj.trainer = pl.Trainer( + max_epochs=1, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + + return obj + def optimize_hparams( self, X, diff --git a/docs/api/models/Models.rst b/docs/api/models/Models.rst index fd02d9f..7b64f5e 100644 --- a/docs/api/models/Models.rst +++ b/docs/api/models/Models.rst @@ -169,39 +169,50 @@ deeptab.models :members: :undoc-members: -.. autoclass:: deeptab.models.ModernNCAClassifier + +Experimental Models +------------------- + +.. warning:: + + The classes below live in ``deeptab.models.experimental``. Their API may + change without a deprecation cycle. Import them explicitly:: + + from deeptab.models.experimental import ModernNCAClassifier + +.. autoclass:: deeptab.models.experimental.ModernNCAClassifier :members: :undoc-members: -.. autoclass:: deeptab.models.ModernNCARegressor +.. autoclass:: deeptab.models.experimental.ModernNCARegressor :members: :undoc-members: -.. autoclass:: deeptab.models.ModernNCALSS +.. autoclass:: deeptab.models.experimental.ModernNCALSS :members: :undoc-members: -.. autoclass:: deeptab.models.TangosClassifier +.. autoclass:: deeptab.models.experimental.TangosClassifier :members: :undoc-members: -.. autoclass:: deeptab.models.TangosRegressor +.. autoclass:: deeptab.models.experimental.TangosRegressor :members: :undoc-members: -.. autoclass:: deeptab.models.TangosLSS +.. autoclass:: deeptab.models.experimental.TangosLSS :members: :undoc-members: -.. autoclass:: deeptab.models.TromptClassifier +.. autoclass:: deeptab.models.experimental.TromptClassifier :members: :undoc-members: -.. autoclass:: deeptab.models.TromptRegressor +.. autoclass:: deeptab.models.experimental.TromptRegressor :members: :undoc-members: -.. autoclass:: deeptab.models.TromptLSS +.. autoclass:: deeptab.models.experimental.TromptLSS :members: :undoc-members: diff --git a/docs/api/models/autoint.rst b/docs/api/models/autoint.rst new file mode 100644 index 0000000..07a5e07 --- /dev/null +++ b/docs/api/models/autoint.rst @@ -0,0 +1,46 @@ +AutoInt +======= + +Automatic feature Interaction learning via multi-head self-attention on feature +embeddings. Each input feature is projected into an embedding and the +embeddings are passed through stacked multi-head attention layers. Residual +connections allow the model to combine the original feature representation with +the interaction-augmented representation, making the learned interactions +explicitly additive. + +When to Use +----------- + +When capturing explicit pairwise and higher-order feature interactions is the +primary modelling goal. Historically strong in click-through-rate prediction +and recommendation system benchmarks. + +Limitations +----------- + +- Performance is generally comparable to FTTransformer on most generic tabular + benchmarks; FTTransformer is often a simpler first choice. +- Less effective for very high-dimensional sparse feature spaces compared to + factorisation-machine-based methods. +- The additional residual interaction terms add minor overhead vs plain + Transformer models. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: AutoIntRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: AutoIntClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: AutoIntLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/enode.rst b/docs/api/models/enode.rst new file mode 100644 index 0000000..9d76d46 --- /dev/null +++ b/docs/api/models/enode.rst @@ -0,0 +1,43 @@ +ENODE +===== + +Extended Neural Oblivious Decision Ensembles. ENODE builds on :doc:`node` by +adding explicit feature embedding layers before the decision ensemble. These +embedding layers transform raw input features into richer representations before +they are fed into the differentiable decision trees, improving performance when +the raw feature space is noisy or heterogeneous. + +When to Use +----------- + +Upgrade from NODE when raw feature quality is poor, the data is heterogeneous, +or vanilla NODE underfits. The embedding layers add a small representational +overhead that often pays off on real-world datasets. + +Limitations +----------- + +- Inherits the same fundamental limitations as NODE (high memory, slow training). +- Increased model size compared to plain NODE. +- May be harder to interpret than NODE because the input to the decision + ensemble is no longer the raw feature space. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: ENODERegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: ENODEClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: ENODELSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/fttransformer.rst b/docs/api/models/fttransformer.rst new file mode 100644 index 0000000..460164a --- /dev/null +++ b/docs/api/models/fttransformer.rst @@ -0,0 +1,42 @@ +FTTransformer +============= + +Feature Tokenizer + Transformer. Each input feature — numerical or categorical — +is mapped to a dense token embedding, and the resulting sequence of tokens is +processed through a stack of standard Transformer encoder layers. A ``[CLS]`` +token is prepended and used to produce the final prediction. + +When to Use +----------- + +Strong general-purpose model. Particularly effective on mixed datasets with both +numerical and categorical features where pairwise feature interactions are +important. Typically the first Transformer baseline to try. + +Limitations +----------- + +- Higher memory and compute cost relative to MLP and ResNet. +- Tends to overfit on very small datasets (under ~500 samples); consider adding + dropout or reducing depth. +- Longer training time than simpler architectures. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: FTTransformerRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: FTTransformerClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: FTTransformerLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/index.rst b/docs/api/models/index.rst index 20fe846..864dc51 100644 --- a/docs/api/models/index.rst +++ b/docs/api/models/index.rst @@ -137,7 +137,52 @@ Modules Description :class:`SklearnBaseRegressor` Base class for regression tasks. ======================================= ======================================================================================================= +Experimental Models +------------------- + +.. warning:: + + Experimental models are available from ``deeptab.models.experimental``. + Their API may change without a deprecation cycle. + +.. currentmodule:: deeptab.models.experimental + +======================================= =========================================================================== +Modules Description +======================================= =========================================================================== +:class:`ModernNCAClassifier` ModernNCA for classification tasks. +:class:`ModernNCARegressor` ModernNCA for regression tasks. +:class:`ModernNCALSS` ModernNCA for distributional tasks. +:class:`TangosClassifier` Tangos for classification tasks. +:class:`TangosRegressor` Tangos for regression tasks. +:class:`TangosLSS` Tangos for distributional tasks. +:class:`TromptClassifier` Trompt for classification tasks. +:class:`TromptRegressor` Trompt for regression tasks. +:class:`TromptLSS` Trompt for distributional tasks. +======================================= =========================================================================== + +.. toctree:: + :maxdepth: 1 + :caption: Stable Models + + mlp + resnet + fttransformer + tabtransformer + saint + tabm + tabr + node + ndtf + tabularrnn + mambular + mambatab + mambattention + enode + autoint + .. toctree:: :maxdepth: 1 + :caption: Full API Reference Models diff --git a/docs/api/models/mambatab.rst b/docs/api/models/mambatab.rst new file mode 100644 index 0000000..9eedf78 --- /dev/null +++ b/docs/api/models/mambatab.rst @@ -0,0 +1,41 @@ +MambaTab +======== + +A lightweight Mamba-based architecture that applies a single Mamba SSM block to +a joint representation of all input features. Rather than tokenising each +feature individually, MambaTab concatenates all feature embeddings into one +vector, making it the most computationally efficient model in the Mamba family. + +When to Use +----------- + +Efficiency-focused scenarios where a fast Mamba-based baseline is needed before +scaling to the more expressive :doc:`mambular` architecture. Useful when +training or inference speed is a hard constraint. + +Limitations +----------- + +- The joint input representation loses per-feature granularity compared to + token-level models (FTTransformer, Mambular). +- Less expressive than multi-layer Mambular for complex datasets. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MambaTabRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: MambaTabClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: MambaTabLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/mambattention.rst b/docs/api/models/mambattention.rst new file mode 100644 index 0000000..f648f87 --- /dev/null +++ b/docs/api/models/mambattention.rst @@ -0,0 +1,42 @@ +MambAttention +============= + +Hybrid Mamba + Attention architecture. MambAttention interleaves Mamba SSM +layers with multi-head self-attention layers, allowing the model to capture both +local sequential patterns (via Mamba's linear-time recurrence) and global +dependencies across all features simultaneously (via attention). + +When to Use +----------- + +When you need the memory efficiency of Mamba for local patterns and the +expressiveness of attention for global feature interactions. A natural upgrade +from either :doc:`mambular` or :doc:`fttransformer` when neither alone is +sufficient. + +Limitations +----------- + +- More hyperparameters than either Mambular or FTTransformer alone. +- Higher compute and memory cost than a pure Mamba or pure attention model. +- Fewer community benchmarks available; expect more tuning effort. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MambAttentionRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: MambAttentionClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: MambAttentionLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/mambular.rst b/docs/api/models/mambular.rst new file mode 100644 index 0000000..5dd3998 --- /dev/null +++ b/docs/api/models/mambular.rst @@ -0,0 +1,43 @@ +Mambular +======== + +Sequential Mamba Structured State Space Model (SSM) blocks adapted for tabular +data. Each feature is embedded as a token and the resulting sequence is +processed by stacked Mamba layers, which use efficient linear-time recurrence +rather than quadratic attention. This allows Mambular to scale to longer feature +sequences while keeping memory costs linear. + +When to Use +----------- + +Ordered feature sets or large-scale datasets where Transformer memory costs are +prohibitive. Particularly compelling as an attention-free alternative when the +feature sequence has inherent order (e.g., time-step columns, sensor channels). + +Limitations +----------- + +- Newer architecture with less empirical validation than MLP/ResNet baselines. +- May require more epochs to converge compared to Transformer-based models. +- Performance can be sensitive to the Mamba-specific hyperparameters + (``d_state``, ``expand_factor``). + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MambularRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: MambularClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: MambularLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/mlp.rst b/docs/api/models/mlp.rst new file mode 100644 index 0000000..bfa408a --- /dev/null +++ b/docs/api/models/mlp.rst @@ -0,0 +1,42 @@ +MLP +=== + +A fully-connected feedforward network with configurable depth and width. The +simplest and fastest deep learning baseline for tabular data. Each hidden layer +applies a linear transformation followed by an activation function and optional +dropout. + +When to Use +----------- + +Start here before trying more complex architectures. Works well on most datasets +as a fast, low-cost baseline. Ideal for smaller datasets or when compute budget +is limited. Also useful as a sanity-check model to verify the data pipeline. + +Limitations +----------- + +- Cannot model complex feature interactions without explicit feature engineering. +- May underfit on datasets with strong structural or sequential patterns. +- Performance plateaus with depth due to vanishing gradients (use ResNet if this + is a concern). + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MLPRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: MLPClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: MLPLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/ndtf.rst b/docs/api/models/ndtf.rst new file mode 100644 index 0000000..b8e8a5b --- /dev/null +++ b/docs/api/models/ndtf.rst @@ -0,0 +1,43 @@ +NDTF +==== + +Neural Decision Tree Forest. An ensemble of differentiable soft decision trees +where routing probabilities at each node are learned via sigmoid activations. +A path-probability regularisation term (controlled by ``lamda``) penalises +over-confident or imbalanced routing, encouraging diverse tree usage across the +forest. + +When to Use +----------- + +When interpretability through decision paths is desirable alongside neural +gradient optimisation. Useful as an alternative to NODE when a forest structure +(multiple independent trees) is preferred over oblivious ensembles. + +Limitations +----------- + +- Sensitive to the ``temperature`` and ``lamda`` regularisation hyperparameters. +- Can underfit with too few trees (``n_ensembles``) or overfit with too many. +- Less effective for very high-dimensional data where feature selection at each + split becomes noisy. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: NDTFRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: NDTFClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: NDTFLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/node.rst b/docs/api/models/node.rst new file mode 100644 index 0000000..d011756 --- /dev/null +++ b/docs/api/models/node.rst @@ -0,0 +1,43 @@ +NODE +==== + +Neural Oblivious Decision Ensembles. Each NODE layer is a differentiable +ensemble of oblivious decision trees — trees where the same splitting feature +and threshold is used at every node of a given depth. The trees are made +end-to-end differentiable via entmax transformations, allowing gradient-based +training. + +When to Use +----------- + +When you want the inductive bias of gradient-boosted decision trees inside a +neural framework. Often competitive with gradient boosting on structured tabular +benchmarks while remaining composable as a standard PyTorch layer. + +Limitations +----------- + +- High memory consumption, especially at larger tree depths. +- Slower to train than MLP-based models. +- Sensitive to the ``depth`` hyperparameter; too shallow loses expressiveness, + too deep causes memory and overfitting issues. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: NODERegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: NODEClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: NODELSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/resnet.rst b/docs/api/models/resnet.rst new file mode 100644 index 0000000..67f8398 --- /dev/null +++ b/docs/api/models/resnet.rst @@ -0,0 +1,41 @@ +ResNet +====== + +A deep residual network adapted for tabular data. Skip connections let gradients +flow through deeper stacks without vanishing, enabling more representational +capacity than a plain MLP at the same depth. Each residual block applies two +linear layers with batch normalisation and a skip connection. + +When to Use +----------- + +Choose ResNet when a plain MLP fails to converge well or produces unstable +training curves, or when you need more depth without gradient issues. A good +second step after benchmarking MLP. + +Limitations +----------- + +- More hyperparameters than plain MLP (block size, number of blocks). +- Skip connections add memory overhead. +- May not outperform MLP on small datasets where depth is not beneficial. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: ResNetRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: ResNetClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: ResNetLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/saint.rst b/docs/api/models/saint.rst new file mode 100644 index 0000000..652e7df --- /dev/null +++ b/docs/api/models/saint.rst @@ -0,0 +1,43 @@ +SAINT +===== + +Self-Attention and Intersample Attention Transformer. SAINT augments the +standard column-wise attention of a Transformer with a second attention +mechanism that operates across rows — allowing each sample to attend to other +samples in the batch. This enables the model to leverage inter-sample +relationships during training. + +When to Use +----------- + +When inter-sample relationships are informative, such as in recommendation or +retrieval tasks. Reported strong performance on semi-supervised tabular +benchmarks. Consider SAINT when FTTransformer leaves significant headroom and +more expressive attention is warranted. + +Limitations +----------- + +- Quadratic memory complexity in batch size due to intersample attention. +- Significantly slower than single-sample Transformer models on large batches. +- Gains over simpler models are dataset-dependent; not always worth the extra cost. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: SAINTRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: SAINTClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: SAINTLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/tabm.rst b/docs/api/models/tabm.rst new file mode 100644 index 0000000..d03a7fd --- /dev/null +++ b/docs/api/models/tabm.rst @@ -0,0 +1,41 @@ +TabM +==== + +Batch ensembling applied to an MLP. TabM trains multiple ensemble members that +share most of their weights, with only lightweight per-member scaling factors +making each head distinct. This delivers ensemble-level accuracy at near +single-model memory and compute cost. + +When to Use +----------- + +When you want ensembling diversity without the cost of training multiple +independent models. A strong regularised baseline that often outperforms plain +MLP with minimal extra overhead. + +Limitations +----------- + +- Slightly higher memory footprint than a plain MLP due to the per-member factors. +- The number of ensemble members is an additional hyperparameter to tune. +- Gains diminish beyond a moderate number of members. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabMRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabMClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabMLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/tabr.rst b/docs/api/models/tabr.rst new file mode 100644 index 0000000..779802e --- /dev/null +++ b/docs/api/models/tabr.rst @@ -0,0 +1,43 @@ +TabR +==== + +Retrieval-augmented tabular model. At inference time, TabR retrieves the most +similar training examples from a stored memory of embeddings and uses them as +additional context when computing the prediction. This gives the model access to +local neighbourhood information beyond what is encoded in its weights. + +When to Use +----------- + +Datasets where local similarity structure is informative — rows that are similar +in feature space tend to share similar targets. Effective on low-to-medium-size +datasets where a full nearest-neighbour memory can be maintained affordably. + +Limitations +----------- + +- Inference time scales with training set size as the model must search the + memory store. +- Not suitable for very large datasets (>100 k rows) without approximate + nearest-neighbour indexing. +- Requires keeping the training set in memory during inference. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabRRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabRClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabRLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/tabtransformer.rst b/docs/api/models/tabtransformer.rst new file mode 100644 index 0000000..6bcfdf8 --- /dev/null +++ b/docs/api/models/tabtransformer.rst @@ -0,0 +1,42 @@ +TabTransformer +============== + +Transformer for tabular data with a focus on categorical feature embeddings. +Categorical features are embedded and passed through Transformer encoder layers +to capture inter-categorical dependencies, while numerical features bypass the +attention mechanism and are concatenated at the prediction head. + +When to Use +----------- + +Datasets dominated by high-cardinality categorical features where relationships +between categories are informative. Commonly used in click-through-rate +prediction and entity-heavy tabular problems. + +Limitations +----------- + +- Limited benefit for datasets with mostly numerical features. +- Slower than MLP-based models. +- FTTransformer typically outperforms TabTransformer on mixed datasets because + it tokenises all features uniformly. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabTransformerRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabTransformerClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabTransformerLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/api/models/tabularrnn.rst b/docs/api/models/tabularrnn.rst new file mode 100644 index 0000000..f9cc2b5 --- /dev/null +++ b/docs/api/models/tabularrnn.rst @@ -0,0 +1,44 @@ +TabulaRNN +========= + +Recurrent neural network for tabular data. TabulaRNN treats the feature vector +as a sequence of tokens and processes it with a recurrent cell. The cell type is +configurable: ``RNN``, ``LSTM``, ``GRU``, ``mLSTM`` (matrix LSTM), or +``sLSTM`` (scalar LSTM from the xLSTM family). This makes it a flexible +sequence model that spans classical to modern recurrent architectures. + +When to Use +----------- + +Best suited for datasets where feature ordering encodes meaningful structure — +for example, temporally ordered measurements stored as columns. Also a viable +alternative to Transformer-based models when memory efficiency is a priority. + +Limitations +----------- + +- Performance is sensitive to feature ordering; shuffling columns can + significantly change results. +- May underperform Transformer architectures on unordered tabular data where + positional bias is irrelevant. +- The mLSTM and sLSTM variants are newer and less empirically validated. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabulaRNNRegressor + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabulaRNNClassifier + :members: + :undoc-members: + :noindex: + +.. autoclass:: TabulaRNNLSS + :members: + :undoc-members: + :noindex: diff --git a/docs/developer_guide/model_promotion_policy.md b/docs/developer_guide/model_promotion_policy.md index 2f70f70..7690110 100644 --- a/docs/developer_guide/model_promotion_policy.md +++ b/docs/developer_guide/model_promotion_policy.md @@ -54,16 +54,20 @@ No open GitHub issues labelled `bug` for the model may describe a failure in a c ### 7. Registry -A config class must exist in `deeptab/configs/` and be exported from `deeptab/configs/__init__.py`. The model must be exported from `deeptab/models/__init__.py` and listed in `deeptab/utils/config_mapper.py`. +A config class must exist in `deeptab/configs/` and be exported from `deeptab/configs/__init__.py`. The model must be exported from `deeptab/models/experimental/__init__.py` while experimental, or from `deeptab/models/__init__.py` once stable, and listed in `deeptab/utils/config_mapper.py`. The `MODEL_REGISTRY` in `deeptab/models/_registry.py` must contain an entry with the correct `status` and `import_path`. ## Promotion PR Open a PR titled `feat(): promote to stable`. The PR must: -1. Remove any `.. experimental::` admonition from the model's doc page. -2. Remove any `ExperimentalWarning` raised in `__init__` or `fit`. -3. Remove the experimental badge from the API reference entry. -4. Add the model to the changelog under `### Promoted to Stable`. +1. Move the model file from `deeptab/models/experimental/` to `deeptab/models/` using `git mv`. +2. Update relative imports in the moved file (reduce one `..` level). +3. Remove the model from `deeptab/models/experimental/__init__.py` and its `__all__`. +4. Add the model to `deeptab/models/__init__.py` imports and `__all__`. +5. Update `MODEL_REGISTRY` in `deeptab/models/_registry.py`: change `status` to `"stable"` and `import_path` to `"deeptab.models"`. +6. Remove any `.. experimental::` admonition from the model's doc page. +7. Remove the experimental badge from the API reference entry. +8. Add the model to the changelog under `### Promoted to Stable`. Approval requires at least one maintainer review beyond the author. Use the promotion checklist in the PR template to track each requirement. diff --git a/docs/examples/classification.md b/docs/examples/classification.md index 7c2586b..054dea1 100644 --- a/docs/examples/classification.md +++ b/docs/examples/classification.md @@ -79,6 +79,30 @@ model.fit(X_train, y_train, max_epochs=50) print(model.evaluate(X_test, y_test)) ``` +## All stable classifiers + +Swap `MambularClassifier` for any class below — no other code changes are needed: + +| Class | Architecture | Notes | +| -------------------------- | ------------------------------------- | ------------------------------------ | +| `MLPClassifier` | Feedforward MLP | Fastest baseline | +| `ResNetClassifier` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerClassifier` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerClassifier` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTClassifier` | Self + intersample attention | Good for semi-supervised settings | +| `TabMClassifier` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRClassifier` | Retrieval-augmented | Strong when local similarity matters | +| `NODEClassifier` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFClassifier` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNClassifier` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularClassifier` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabClassifier` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionClassifier` | Mamba + attention hybrid | Local + global patterns | +| `ENODEClassifier` | Extended NODE | NODE with feature embeddings | +| `AutoIntClassifier` | Attention-based interaction | Explicit feature crossing | + +Experimental classifiers (`ModernNCAClassifier`, `TromptClassifier`, `TangosClassifier`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). + ## Next steps - [Key Concepts](../key_concepts) — learn how to tune hyperparameters via config objects. diff --git a/docs/examples/distributional.md b/docs/examples/distributional.md index 31e537a..75af6e0 100644 --- a/docs/examples/distributional.md +++ b/docs/examples/distributional.md @@ -76,6 +76,30 @@ model.fit(X_train, y_train, family="normal", max_epochs=50) print(model.evaluate(X_test, y_test)) ``` +## All stable LSS models + +Swap `MambularLSS` for any class below — pass `family=` to `.fit()` to select the output distribution: + +| Class | Architecture | Notes | +| ------------------- | ------------------------------------- | ------------------------------------ | +| `MLPLSS` | Feedforward MLP | Fastest baseline | +| `ResNetLSS` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerLSS` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerLSS` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTLSS` | Self + intersample attention | Good for semi-supervised settings | +| `TabMLSS` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRLSS` | Retrieval-augmented | Strong when local similarity matters | +| `NODELSS` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFLSS` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNLSS` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularLSS` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabLSS` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionLSS` | Mamba + attention hybrid | Local + global patterns | +| `ENODELSS` | Extended NODE | NODE with feature embeddings | +| `AutoIntLSS` | Attention-based interaction | Explicit feature crossing | + +Experimental LSS models (`ModernNCALSS`, `TromptLSS`, `TangosLSS`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). + ## Next steps - [Key Concepts](../key_concepts) — understand the `LSS` task variant and available distribution families. diff --git a/docs/examples/experimental.md b/docs/examples/experimental.md new file mode 100644 index 0000000..0b0c8c7 --- /dev/null +++ b/docs/examples/experimental.md @@ -0,0 +1,121 @@ +# Using Experimental Models + +Experimental models live in `deeptab.models.experimental`. Their API may change +without a deprecation cycle, but they are otherwise fully functional and follow +the same `fit` / `predict` / `evaluate` interface as stable models. + +```{warning} +Experimental models are not covered by semantic versioning guarantees. +Pin your DeepTab version (`deeptab==x.y.z`) if you use them in production code +to avoid unexpected breakage after upgrades. +``` + +## Import path + +```python +# stable models — imported directly from deeptab.models +from deeptab.models import MambularClassifier + +# experimental models — always import from deeptab.models.experimental +from deeptab.models.experimental import TromptClassifier, ModernNCARegressor, TangosLSS +``` + +Importing an experimental class directly from `deeptab.models` (the old path) +still works but raises a `DeprecationWarning`: + +```python +# raises DeprecationWarning — update the import +from deeptab.models import TromptClassifier +``` + +--- + +## End-to-end example — Trompt for classification + +### Setup + +```python +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +from deeptab.models.experimental import TromptClassifier +``` + +### Generate data + +```python +np.random.seed(42) + +n_samples, n_features, n_classes = 800, 6, 3 +X = np.random.randn(n_samples, n_features) +y = np.random.randint(0, n_classes, size=n_samples) + +df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(n_features)]) +X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=42) +``` + +### Train + +```python +model = TromptClassifier() +model.fit(X_train, y_train, max_epochs=10) +``` + +### Evaluate + +```python +metrics = model.evaluate(X_test, y_test) +print(metrics) +``` + +### Predict + +```python +preds = model.predict(X_test) +proba = model.predict_proba(X_test) +``` + +--- + +## End-to-end example — ModernNCA for regression + +```python +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +from deeptab.models.experimental import ModernNCARegressor + +np.random.seed(0) +n_samples, n_features = 800, 5 +X = np.random.randn(n_samples, n_features) +y = X @ np.random.randn(n_features) + np.random.randn(n_samples) * 0.1 + +df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(n_features)]) +X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0) + +model = ModernNCARegressor(d_model=64, n_layers=4) +model.fit(X_train, y_train, max_epochs=10) + +metrics = model.evaluate(X_test, y_test) +print(metrics) +``` + +--- + +## Switching between experimental and stable + +The API is identical — only the import path changes. When a model is promoted to +stable, update the import and nothing else: + +```python +# Before promotion +from deeptab.models.experimental import TromptClassifier + +# After promotion (no other code changes needed) +from deeptab.models import TromptClassifier +``` + +See [Model Promotion Policy](../developer_guide/model_promotion_policy) for the +criteria a model must meet before it moves to stable. diff --git a/docs/examples/regression.md b/docs/examples/regression.md index b21dea8..48d847e 100644 --- a/docs/examples/regression.md +++ b/docs/examples/regression.md @@ -77,6 +77,30 @@ model.fit(X_train, y_train, max_epochs=50) print(model.evaluate(X_test, y_test)) ``` +## All stable regressors + +Swap `MambularRegressor` for any class below — no other code changes are needed: + +| Class | Architecture | Notes | +| ------------------------- | ------------------------------------- | ------------------------------------ | +| `MLPRegressor` | Feedforward MLP | Fastest baseline | +| `ResNetRegressor` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerRegressor` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerRegressor` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTRegressor` | Self + intersample attention | Good for semi-supervised settings | +| `TabMRegressor` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRRegressor` | Retrieval-augmented | Strong when local similarity matters | +| `NODERegressor` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFRegressor` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNRegressor` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularRegressor` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabRegressor` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionRegressor` | Mamba + attention hybrid | Local + global patterns | +| `ENODERegressor` | Extended NODE | NODE with feature embeddings | +| `AutoIntRegressor` | Attention-based interaction | Explicit feature crossing | + +Experimental regressors (`ModernNCARegressor`, `TromptRegressor`, `TangosRegressor`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). + ## Next steps - [Key Concepts](../key_concepts) — learn how to tune hyperparameters via config objects. diff --git a/docs/index.rst b/docs/index.rst index ff6df57..874df7c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,7 @@ examples/classification examples/regression examples/distributional + examples/experimental .. toctree:: :name: API Reference diff --git a/docs/key_concepts.md b/docs/key_concepts.md index 25e498d..51db1e1 100644 --- a/docs/key_concepts.md +++ b/docs/key_concepts.md @@ -33,6 +33,27 @@ from deeptab.models import MambularRegressor # regression from deeptab.models import MambularLSS # distributional regression ``` +## Stable vs experimental models + +DeepTab ships models at two tiers: + +| Tier | Import path | Guarantee | +| ---------------- | --------------------------------------------- | ------------------------------------------- | +| **Stable** | `from deeptab.models import ...` | Public API frozen under semantic versioning | +| **Experimental** | `from deeptab.models.experimental import ...` | May change without a deprecation cycle | + +Always use the explicit experimental import path to signal that you accept the instability: + +```python +# stable +from deeptab.models import FTTransformerClassifier + +# experimental — explicit path required +from deeptab.models.experimental import TromptClassifier +``` + +See [Using experimental models](examples/experimental) for a full worked example. + ## Configuring hyperparameters Every model has a corresponding config class in `deeptab.configs` that documents all available hyperparameters. You can either pass hyperparameters directly to the constructor or via a config object: diff --git a/docs/overview.md b/docs/overview.md index 7b04f5f..2949dda 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -15,6 +15,8 @@ Tabular data is the most common format in applied machine learning, yet most dee All models support regression, classification, and distributional regression out of the box. Import them as `Regressor`, `Classifier`, or `LSS`. +### Stable + | Model | Architecture | Reference | | ---------------- | ---------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | | `Mambular` | Sequential Mamba (SSM) blocks for tabular data | [Thielmann et al. (2024)](https://arxiv.org/abs/2408.06291) | @@ -32,9 +34,16 @@ All models support regression, classification, and distributional regression out | `TabulaRNN` | Recurrent neural network for tabular data | [Thielmann et al. (2025)](https://arxiv.org/pdf/2411.17207) | | `ENODE` | Extended NODE variant | — | | `AutoInt` | Automatic feature interaction via attention | — | -| `ModernNCA` | Modern neural classification architecture | — | -| `Trompt` | Tabular-specific prompting model | — | -| `TANGOS` | Tabular model with graph-based structure | — | + +### Experimental + +Experimental models are imported from `deeptab.models.experimental`. Their API may change without a deprecation cycle. See [Using experimental models](examples/experimental) for a worked example. + +| Model | Architecture | Reference | +| ----------- | ----------------------------------------- | --------- | +| `ModernNCA` | Modern neural classification architecture | — | +| `Trompt` | Tabular-specific prompting model | — | +| `Tangos` | Tabular model with graph-based structure | — | ## Next steps diff --git a/justfile b/justfile index 65938f3..e3e5aeb 100644 --- a/justfile +++ b/justfile @@ -2,9 +2,10 @@ default: @just --list --unsorted -# install dependencies and set up all pre-commit hooks +# install dependencies, editable package, and set up all pre-commit hooks install: poetry install + poetry run pip install -e . --quiet poetry run pre-commit install --hook-type commit-msg --hook-type pre-commit --hook-type pre-push # update dependencies and pre-commit hook revisions diff --git a/poetry.lock b/poetry.lock index 785cec0..99b04ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -780,6 +780,22 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "delu" +version = "0.0.26" +description = "Deep Learning Utilities for PyTorch users." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "delu-0.0.26-py3-none-any.whl", hash = "sha256:b074f0ad5664b80d70096ead0bd6c69133a9ff883a0f7b341b7570619c91870b"}, + {file = "delu-0.0.26.tar.gz", hash = "sha256:460fb6e1bb8b8fcc5639337486c3d32752820147be7442f1d52bbf93beb919a4"}, +] + +[package.dependencies] +numpy = ">=1.21,<3" +torch = ">=1.9,<3" + [[package]] name = "distlib" version = "0.3.9" @@ -878,6 +894,35 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] +[[package]] +name = "faiss-cpu" +version = "1.13.2" +description = "A library for efficient similarity search and clustering of dense vectors." +optional = false +python-versions = "<3.15,>=3.10" +groups = ["main"] +files = [ + {file = "faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_arm64.whl", hash = "sha256:a9064eb34f8f64438dd5b95c8f03a780b1a3f0b99c46eeacb1f0b5d15fc02dc1"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_x86_64.whl", hash = "sha256:c8d097884521e1ecaea6467aeebbf1aa56ee4a36350b48b2ca6b39366565c317"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ee330a284042c2480f2e90450a10378fd95655d62220159b1408f59ee83ebf1"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ab88ee287c25a119213153d033f7dd64c3ccec466ace267395872f554b648cd7"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:85511129b34f890d19c98b82a0cd5ffb27d89d1cec2ee41d2621ee9f9ef8cf3f"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b32eb4065bac352b52a9f5ae07223567fab0a976c7d05017c01c45a1c24264f"}, + {file = "faiss_cpu-1.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eb8bf5dd96465d043c22195afbe8276d5197b710704290d9b454144a0ad892ed"}, + {file = "faiss_cpu-1.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:7c5944d7807d58fe7244b6aba06be710ee7ed99343365ed92699349efe979f51"}, + {file = "faiss_cpu-1.13.2-cp311-cp311-win_arm64.whl", hash = "sha256:19508a1badfb36e456c1c8664eeb948349f604db5c7545f277a0126b4a84b080"}, + {file = "faiss_cpu-1.13.2-cp312-cp312-win_amd64.whl", hash = "sha256:b82c01d30430dd7b1fa442001b9099735d1a82f6bb72033acdc9206d5ac66a64"}, + {file = "faiss_cpu-1.13.2-cp312-cp312-win_arm64.whl", hash = "sha256:2c4f696ae76e7c97cbc12311db83aaf1e7f4f7be06a3ffea7e5b0e8ec1fd805b"}, + {file = "faiss_cpu-1.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:cb4b5ee184816a4b099162ac93c0d7f0033d81a88e7c1291d0a9cc41ec348984"}, + {file = "faiss_cpu-1.13.2-cp313-cp313-win_arm64.whl", hash = "sha256:1243967eeb2298791ff7f3683a4abd2100d7e6ec7542ca05c3b75d47a7f621e5"}, + {file = "faiss_cpu-1.13.2-cp314-cp314-win_amd64.whl", hash = "sha256:c8b645e7d56591aa35dc75415bb53a62e4a494dba010e16f4b67daeffd830bd7"}, + {file = "faiss_cpu-1.13.2-cp314-cp314-win_arm64.whl", hash = "sha256:8113a2a80b59fe5653cf66f5c0f18be0a691825601a52a614c30beb1fca9bc7c"}, +] + +[package.dependencies] +numpy = ">=1.25.0,<3.0" +packaging = "*" + [[package]] name = "fastjsonschema" version = "2.21.2" @@ -4888,4 +4933,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "83960ab51d000c28565beef5350b987f673689b24e033856dc43d50bdf4284ab" +content-hash = "cd3e8982de0d6cb28767f7bfe3aef250c1254eff982f84a740633aa319bb09b0" diff --git a/pyproject.toml b/pyproject.toml index 0f33720..a611280 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ einops = "^0.8.0" accelerate = "^1.2.1" scipy = "^1.15.0" pretab = "^0.0.1" +delu = "*" +faiss-cpu = "*" [tool.poetry.group.dev.dependencies] pytest = "^8.1" @@ -66,6 +68,22 @@ repository = "https://github.com/OpenTabular/deeptab" package = "https://pypi.org/project/deeptab/" +# test configuration +[tool.pytest.ini_options] +pythonpath = ["."] +filterwarnings = [ + # Lightning trainer noise (dataloader workers, log interval, checkpoint dir, tensorboard) + "ignore::UserWarning:lightning", + # PyTorch transformer / RNN config warnings + "ignore::UserWarning:torch", + # properscoring invalid escape sequences + "ignore::DeprecationWarning:properscoring", + # faiss SwigPy builtin type warnings (emitted during import) + "ignore::DeprecationWarning:importlib", + # delu.nn.Lambda custom-function deprecation + "ignore::DeprecationWarning:delu", +] + # code quality tools [tool.pyright] include = ["deeptab", "tests"] diff --git a/tests/test_model_exports.py b/tests/test_model_exports.py new file mode 100644 index 0000000..4d1e58f --- /dev/null +++ b/tests/test_model_exports.py @@ -0,0 +1,199 @@ +""" +Tests verifying the stable / experimental model export split. + +Stable models must be importable from ``deeptab.models``. +Experimental models must be importable from ``deeptab.models.experimental`` +and must *not* appear in ``deeptab.models.__all__``. +""" + +import importlib + +import pytest + +# --------------------------------------------------------------------------- +# Expected exports +# --------------------------------------------------------------------------- + +STABLE_CLASSES = [ + "AutoIntClassifier", + "AutoIntLSS", + "AutoIntRegressor", + "ENODEClassifier", + "ENODELSS", + "ENODERegressor", + "FTTransformerClassifier", + "FTTransformerLSS", + "FTTransformerRegressor", + "MambAttentionClassifier", + "MambAttentionLSS", + "MambAttentionRegressor", + "MambaTabClassifier", + "MambaTabLSS", + "MambaTabRegressor", + "MambularClassifier", + "MambularLSS", + "MambularRegressor", + "MLPClassifier", + "MLPLSS", + "MLPRegressor", + "NDTFClassifier", + "NDTFLSS", + "NDTFRegressor", + "NODEClassifier", + "NODELSS", + "NODERegressor", + "ResNetClassifier", + "ResNetLSS", + "ResNetRegressor", + "SAINTClassifier", + "SAINTLSS", + "SAINTRegressor", + "TabMClassifier", + "TabMLSS", + "TabMRegressor", + "TabRClassifier", + "TabRLSS", + "TabRRegressor", + "TabTransformerClassifier", + "TabTransformerLSS", + "TabTransformerRegressor", + "TabulaRNNClassifier", + "TabulaRNNLSS", + "TabulaRNNRegressor", +] + +EXPERIMENTAL_CLASSES = [ + "ModernNCAClassifier", + "ModernNCALSS", + "ModernNCARegressor", + "TangosClassifier", + "TangosLSS", + "TangosRegressor", + "TromptClassifier", + "TromptLSS", + "TromptRegressor", +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _import(module_path: str, attr: str): + mod = importlib.import_module(module_path) + return getattr(mod, attr) + + +# --------------------------------------------------------------------------- +# Stable models +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("class_name", STABLE_CLASSES) +def test_stable_model_importable(class_name: str): + """Every stable class is importable from deeptab.models.""" + cls = _import("deeptab.models", class_name) + assert cls is not None + + +@pytest.mark.parametrize("class_name", STABLE_CLASSES) +def test_stable_model_in_all(class_name: str): + """Every stable class is listed in deeptab.models.__all__.""" + import deeptab.models as m + + assert class_name in m.__all__, f"{class_name!r} missing from deeptab.models.__all__" + + +# --------------------------------------------------------------------------- +# Experimental models +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_model_importable(class_name: str): + """Every experimental class is importable from deeptab.models.experimental.""" + cls = _import("deeptab.models.experimental", class_name) + assert cls is not None + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_model_in_experimental_all(class_name: str): + """Every experimental class is listed in deeptab.models.experimental.__all__.""" + import deeptab.models.experimental as exp + + assert class_name in exp.__all__, f"{class_name!r} missing from deeptab.models.experimental.__all__" + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_model_not_in_stable_all(class_name: str): + """Experimental classes must not leak into deeptab.models.__all__.""" + import deeptab.models as m + + assert class_name not in m.__all__, f"{class_name!r} should not be in deeptab.models.__all__" + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + + +def test_registry_stable_import_paths(): + """All stable entries in MODEL_REGISTRY have import_path == 'deeptab.models'.""" + from deeptab.models._registry import MODEL_REGISTRY + + for name, info in MODEL_REGISTRY.items(): + if info.status == "stable": + assert info.import_path == "deeptab.models", ( + f"{name}: expected import_path 'deeptab.models', got {info.import_path!r}" + ) + + +def test_registry_experimental_import_paths(): + """All experimental entries have import_path == 'deeptab.models.experimental'.""" + from deeptab.models._registry import MODEL_REGISTRY + + for name, info in MODEL_REGISTRY.items(): + if info.status == "experimental": + assert info.import_path == "deeptab.models.experimental", ( + f"{name}: expected 'deeptab.models.experimental', got {info.import_path!r}" + ) + + +# --------------------------------------------------------------------------- +# Deprecation warnings +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_via_stable_emits_deprecation(class_name: str): + """Accessing an experimental class via deeptab.models emits DeprecationWarning.""" + import deeptab.models as m + + with pytest.warns(DeprecationWarning, match=class_name): + cls = getattr(m, class_name) + assert cls is not None + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_deprecation_message_contains_new_path(class_name: str): + """Deprecation warning message includes the correct new import path.""" + import warnings + + import deeptab.models as m + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + getattr(m, class_name) + + assert any("deeptab.models.experimental" in str(w.message) for w in caught), ( + f"Expected new path in deprecation message for {class_name!r}" + ) + + +def test_unknown_attribute_raises(): + """Accessing a truly unknown attribute on deeptab.models still raises AttributeError.""" + import deeptab.models as m + + with pytest.raises(AttributeError): + _ = m.ThisDoesNotExist diff --git a/tests/test_models.py b/tests/test_models.py index f0fadf7..fdc4474 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,29 +1,70 @@ """ End-to-end behavioural tests for the sklearn-compatible model API. -Tests cover fit → predict → evaluate for a representative subset of models -(MLP, ResNet, FTTransformer, Mambular) across all three task variants -(Classifier, Regressor, LSS). A small synthetic dataset keeps CI fast. +Tests cover fit → predict → evaluate for all 15 stable models across all +three task variants (Classifier, Regressor, LSS). A small synthetic dataset +keeps CI fast. """ +import platform + import numpy as np import pandas as pd import pytest from sklearn.model_selection import train_test_split from deeptab.models import ( + ENODELSS, MLPLSS, + NDTFLSS, + NODELSS, + SAINTLSS, + AutoIntClassifier, + AutoIntLSS, + AutoIntRegressor, + ENODEClassifier, + ENODERegressor, FTTransformerClassifier, FTTransformerLSS, FTTransformerRegressor, + MambaTabClassifier, + MambaTabLSS, + MambaTabRegressor, + MambAttentionClassifier, + MambAttentionLSS, + MambAttentionRegressor, MambularClassifier, MambularLSS, MambularRegressor, MLPClassifier, MLPRegressor, + NDTFClassifier, + NDTFRegressor, + NODEClassifier, + NODERegressor, ResNetClassifier, ResNetLSS, ResNetRegressor, + SAINTClassifier, + SAINTRegressor, + TabMClassifier, + TabMLSS, + TabMRegressor, + TabRClassifier, + TabRLSS, + TabRRegressor, + TabTransformerClassifier, + TabTransformerLSS, + TabTransformerRegressor, + TabulaRNNClassifier, + TabulaRNNLSS, + TabulaRNNRegressor, +) + +_macos_arm64 = platform.system() == "Darwin" and platform.machine() == "arm64" +_skip_tabr = pytest.mark.skipif( + _macos_arm64, + reason="faiss-cpu from PyPI segfaults on macOS arm64; install via conda for TabR support", ) # --------------------------------------------------------------------------- @@ -56,6 +97,29 @@ def regression_data(): return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) +@pytest.fixture(scope="module") +def classification_data_with_cat(): + """Fixture with one categorical column — required by TabTransformer.""" + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y_cont = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + y = pd.qcut(y_cont, q=N_CLASSES, labels=False) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + df["cat_col"] = rng.choice(["A", "B", "C"], size=N_SAMPLES) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + +@pytest.fixture(scope="module") +def regression_data_with_cat(): + """Fixture with one categorical column — required by TabTransformer.""" + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + df["cat_col"] = rng.choice(["A", "B", "C"], size=N_SAMPLES) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + # --------------------------------------------------------------------------- # Classifier tests # --------------------------------------------------------------------------- @@ -65,6 +129,16 @@ def regression_data(): ResNetClassifier, FTTransformerClassifier, MambularClassifier, + TabMClassifier, + pytest.param(TabRClassifier, marks=_skip_tabr), + NODEClassifier, + NDTFClassifier, + SAINTClassifier, + AutoIntClassifier, + MambaTabClassifier, + MambAttentionClassifier, + TabulaRNNClassifier, + ENODEClassifier, ] @@ -115,6 +189,16 @@ def test_classifier_evaluate_returns_dict(cls, classification_data): ResNetRegressor, FTTransformerRegressor, MambularRegressor, + TabMRegressor, + pytest.param(TabRRegressor, marks=_skip_tabr), + NODERegressor, + NDTFRegressor, + SAINTRegressor, + AutoIntRegressor, + MambaTabRegressor, + MambAttentionRegressor, + TabulaRNNRegressor, + ENODERegressor, ] @@ -149,6 +233,16 @@ def test_regressor_evaluate_returns_dict(cls, regression_data): ResNetLSS, FTTransformerLSS, MambularLSS, + TabMLSS, + pytest.param(TabRLSS, marks=_skip_tabr), + NODELSS, + NDTFLSS, + SAINTLSS, + AutoIntLSS, + MambaTabLSS, + MambAttentionLSS, + TabulaRNNLSS, + ENODELSS, ] @@ -173,3 +267,57 @@ def test_lss_evaluate_returns_dict(cls, regression_data): metrics = model.evaluate(X_test, y_test) assert isinstance(metrics, dict), f"{cls.__name__}.evaluate should return a dict" assert len(metrics) > 0, f"{cls.__name__}.evaluate returned an empty dict" + + +# --------------------------------------------------------------------------- +# Config serialisation round-trip (Requirement 5) +# --------------------------------------------------------------------------- + +ALL_ESTIMATOR_CLASSES = ( + CLASSIFIERS + REGRESSORS + LSS_MODELS + [TabTransformerClassifier, TabTransformerRegressor, TabTransformerLSS] +) + + +@pytest.mark.parametrize("cls", ALL_ESTIMATOR_CLASSES) +def test_config_serialisation_roundtrip(cls): + """get_params() → construct new instance → config values survive.""" + model = cls() + params = model.get_params() + + # Constructing a second instance with the same params must not raise. + model2 = cls(**params) + + # All config kwargs must round-trip exactly. + for key, value in model.config_kwargs.items(): + assert getattr(model2.config, key, object()) == value, ( + f"{cls.__name__}: config.{key}={value!r} did not survive get_params round-trip" + ) + + +# --------------------------------------------------------------------------- +# TabTransformer — requires at least one categorical feature +# --------------------------------------------------------------------------- + +TAB_TRANSFORMER_MODELS = [ + (TabTransformerClassifier, "classification"), + (TabTransformerRegressor, "regression"), + (TabTransformerLSS, "lss"), +] + + +@pytest.mark.parametrize("cls,task", TAB_TRANSFORMER_MODELS) +def test_tabtransformer_fit_predict(cls, task, classification_data_with_cat, regression_data_with_cat): + if task == "classification": + X_train, X_test, y_train, _y_test = classification_data_with_cat + else: + X_train, X_test, y_train, _y_test = regression_data_with_cat + + model = cls() + if task == "lss": + model.fit(X_train, y_train, family="normal", **FIT_KWARGS) + else: + model.fit(X_train, y_train, **FIT_KWARGS) + + preds = model.predict(X_test) + assert preds.shape[0] == len(X_test), f"{cls.__name__}.predict returned unexpected shape" + assert np.isfinite(preds).all(), f"{cls.__name__}.predict returned non-finite values" diff --git a/tests/test_save_load.py b/tests/test_save_load.py new file mode 100644 index 0000000..07db1df --- /dev/null +++ b/tests/test_save_load.py @@ -0,0 +1,153 @@ +""" +Round-trip save / load tests — Requirement 4. + +For each task type (Regressor, Classifier, LSS) we: + 1. Fit a small model. + 2. Record predictions on a held-out set. + 3. Save the model to a temporary file. + 4. Load it back into a fresh object. + 5. Assert the predictions are bit-for-bit identical. + +We use the lightest available model (MLP) to keep CI fast. +""" + +import os +import tempfile +from typing import Any + +import numpy as np +import pandas as pd +import pytest +from sklearn.model_selection import train_test_split + +from deeptab.models import MLPLSS, MLPClassifier, MLPRegressor + +# --------------------------------------------------------------------------- +# Shared dataset parameters +# --------------------------------------------------------------------------- + +N_SAMPLES = 200 +N_FEATURES = 6 +N_CLASSES = 3 +RANDOM_STATE = 7 +FIT_KWARGS: dict[str, Any] = {"max_epochs": 2, "batch_size": 64} + + +@pytest.fixture(scope="module") +def regression_data(): + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + +@pytest.fixture(scope="module") +def classification_data(): + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y_cont = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + y = pd.qcut(y_cont, q=N_CLASSES, labels=False) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + +# --------------------------------------------------------------------------- +# Regressor round-trip +# --------------------------------------------------------------------------- + + +def test_regressor_save_load_predictions(regression_data): + X_train, X_test, y_train, _y_test = regression_data + model = MLPRegressor() + model.fit(X_train, y_train, **FIT_KWARGS) + + preds_before = model.predict(X_test) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tmp_path = f.name + try: + model.save(tmp_path) + loaded = MLPRegressor.load(tmp_path) + finally: + os.unlink(tmp_path) + + preds_after = loaded.predict(X_test) + + np.testing.assert_array_equal( + preds_before, + preds_after, + err_msg="MLPRegressor predictions changed after save/load round-trip", + ) + + +def test_regressor_save_raises_when_unfitted(): + model = MLPRegressor() + with pytest.raises(ValueError, match="fitted"): + with tempfile.NamedTemporaryFile(suffix=".pt") as f: + model.save(f.name) + + +# --------------------------------------------------------------------------- +# Classifier round-trip +# --------------------------------------------------------------------------- + + +def test_classifier_save_load_predictions(classification_data): + X_train, X_test, y_train, _y_test = classification_data + model = MLPClassifier() + model.fit(X_train, y_train, **FIT_KWARGS) + + preds_before = model.predict(X_test) + proba_before = model.predict_proba(X_test) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tmp_path = f.name + try: + model.save(tmp_path) + loaded = MLPClassifier.load(tmp_path) + finally: + os.unlink(tmp_path) + + preds_after = loaded.predict(X_test) + proba_after = loaded.predict_proba(X_test) + + np.testing.assert_array_equal( + preds_before, + preds_after, + err_msg="MLPClassifier.predict changed after save/load round-trip", + ) + np.testing.assert_array_equal( + proba_before, + proba_after, + err_msg="MLPClassifier.predict_proba changed after save/load round-trip", + ) + + +# --------------------------------------------------------------------------- +# LSS round-trip +# --------------------------------------------------------------------------- + + +def test_lss_save_load_predictions(regression_data): + X_train, X_test, y_train, _y_test = regression_data + model = MLPLSS() + model.fit(X_train, y_train, family="normal", **FIT_KWARGS) + + preds_before = model.predict(X_test) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tmp_path = f.name + try: + model.save(tmp_path) + loaded = MLPLSS.load(tmp_path) + finally: + os.unlink(tmp_path) + + preds_after = loaded.predict(X_test) + + np.testing.assert_array_equal( + preds_before, + preds_after, + err_msg="MLPLSS predictions changed after save/load round-trip", + )