Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
efe643d
feat: stable and experimental model separation
mkumar73 May 14, 2026
c592731
fix: resolve stale site-packages
mkumar73 May 14, 2026
ad23aaf
fix: import error after experimental namespace changes
mkumar73 May 14, 2026
a993829
docs: docstring added to the models ndtf, tabularnn and mambatab
mkumar73 May 14, 2026
82ba57a
docs: update docs for stable/experimental model split
mkumar73 May 14, 2026
34c9188
docs: add individual model pages, stable/experimental split tables
mkumar73 May 14, 2026
0ceb06f
fix: DefaultTabRConfig export
mkumar73 May 14, 2026
3a616ce
docs: tables created for examples
mkumar73 May 14, 2026
932ea31
fix: add delu dependency for TabR
mkumar73 May 14, 2026
fc891b4
fix: added faiss-cpu
mkumar73 May 14, 2026
1867178
fix(ndtf): correct ensemble aggregation for multi-class and LSS outputs
mkumar73 May 14, 2026
45b4b26
fix(tabtrans): add dedicated LayerNorm for numerical features instead…
mkumar73 May 14, 2026
97bed3e
feat(models): add save/load, fix get_params and LSS family pickling
mkumar73 May 14, 2026
a01a2e2
test: full 15-model coverage with save/load, config round-trip, and T…
mkumar73 May 14, 2026
97e0f1e
fix(lint): drop unused unpacked variables and dead code
mkumar73 May 14, 2026
f666805
fix(types): guard task_model None, use getattr for get_params, annota…
mkumar73 May 14, 2026
0b29e43
fix(tabr): use regression label encoder for LSS by forwarding lss flag
mkumar73 May 14, 2026
971a2fb
fix(tabr): cast candidate_y to float in regression/LSS label encoder …
mkumar73 May 14, 2026
9180fb2
fix: suppress family hparam warning
mkumar73 May 14, 2026
a0f70f4
chore: suppress tool warnings
mkumar73 May 14, 2026
f1fc937
fix: test case fixed for windows
mkumar73 May 14, 2026
b081d4b
fix: duplicate entry in docs for model classes
mkumar73 May 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deeptab/arch_utils/layer_utils/batch_ensemble_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor:
"""
# Check input shape and expand if necessary
if x.dim() == 3: # Case: (B, L, D) - no ensembles
batch_size, seq_len, input_size = x.shape
batch_size, seq_len, _ = x.shape
# Shape: (B, L, ensemble_size, D)
x = x.unsqueeze(2).expand(-1, -1, self.ensemble_size, -1)
elif x.dim() == 4 and x.size(2) == self.ensemble_size: # Case: (B, L, ensemble_size, D)
Expand Down Expand Up @@ -451,7 +451,7 @@ def forward(self, query, key, value, mask=None):
If the ensemble size `E` does not match `self.ensemble_size`.
"""

N, S, E, D = query.size()
N, S, E, _ = query.size()
if E != self.ensemble_size:
raise ValueError("Ensemble size mismatch.")

Expand Down
4 changes: 2 additions & 2 deletions deeptab/arch_utils/lstm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def forward(self, x):
"""
if x.ndim != 3:
raise ValueError("Input tensor must have 3 dimensions (batch, sequence_length, input_size)")
B, N, D = x.shape
B, N, _ = x.shape
device = x.device

# Initialize states dynamically based on input shape
Expand Down Expand Up @@ -293,7 +293,7 @@ def forward(self, x):
torch.Tensor
Output tensor of shape (batch, sequence_length, input_size).
"""
B, N, D = x.shape
B, N, _ = x.shape
device = x.device

# Initialize states dynamically based on input shape
Expand Down
2 changes: 1 addition & 1 deletion deeptab/arch_utils/mamba_utils/mamba_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def forward(self, x):

if self.bidirectional:
xz_bwd = self.in_proj_bwd(x)
x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1)
x_bwd, _ = xz_bwd.chunk(2, dim=-1)

x_bwd = x_bwd.transpose(1, 2)
x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L]
Expand Down
5 changes: 2 additions & 3 deletions deeptab/arch_utils/transformer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,10 @@ def forward(self, x, mask: torch.Tensor = None): # type: ignore
The output tensor of shape (N, S, E, D).
"""
if x.dim() == 3: # Case: (B, L, D) - no ensembles
batch_size, seq_len, input_size = x.shape
# Shape: (B, L, ensemble_size, D)
x = x.unsqueeze(2).expand(-1, -1, self.ensemble_size, -1)
elif x.dim() == 4 and x.size(2) == self.ensemble_size: # Case: (B, L, ensemble_size, D)
batch_size, seq_len, ensemble_size, _ = x.shape
_, _, ensemble_size, _ = x.shape
if ensemble_size != self.ensemble_size:
raise ValueError(f"Input shape {x.shape} is invalid. Expected shape: (B, S, ensemble_size, N)")
else:
Expand Down Expand Up @@ -425,7 +424,7 @@ def forward(self, x):
x: Input embeddings of shape (N, J, D),
where N = batch size, J = number of features, D = embedding dimension.
"""
_, n, d = x.shape
_, n, _ = x.shape

for attn1, ff1, attn2, ff2 in self.layers: # type: ignore
# Column-wise attention
Expand Down
3 changes: 0 additions & 3 deletions deeptab/base_models/autoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ def forward(self, *data):
# Apply normalization before attention if prenormalization is enabled
x_residual = layer["norm0"](x_residual) # type: ignore[index]

# Retrieve key-value compression layers
key_compression, value_compression = self._get_kv_compressions(layer)

# Multihead Attention
x_residual, _ = layer["attention"](x_residual, x_residual, x_residual) # type: ignore[index]

Expand Down
11 changes: 6 additions & 5 deletions deeptab/base_models/ndtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def forward(self, *data) -> torch.Tensor:
tree_input = x[:, : self.input_dimensions[idx]]
preds.append(tree(tree_input, return_penalty=False))

preds = torch.stack(preds, dim=1).squeeze(-1)

return preds @ self.tree_weights
preds = torch.stack(preds, dim=1) # (batch, n_ensembles, output_dim)
# Weighted sum over ensemble dim: (batch, output_dim, n_ensembles) @ (n_ensembles, 1)
return (preds.transpose(1, 2) @ self.tree_weights).squeeze(-1)

def penalty_forward(self, *data) -> torch.Tensor:
"""Forward pass of the NDTF model.
Expand Down Expand Up @@ -158,5 +158,6 @@ def penalty_forward(self, *data) -> torch.Tensor:
penalty += pen

# Stack predictions and combine them across trees using self.tree_weights
preds = torch.stack(preds, dim=1).squeeze(-1)
return preds @ self.tree_weights, self.hparams.penalty_factor * penalty # type: ignore
preds = torch.stack(preds, dim=1) # (batch, n_ensembles, output_dim)
# Weighted sum over ensemble dim: (batch, output_dim, n_ensembles) @ (n_ensembles, 1)
return (preds.transpose(1, 2) @ self.tree_weights).squeeze(-1), self.hparams.penalty_factor * penalty # type: ignore
29 changes: 14 additions & 15 deletions deeptab/base_models/tabr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ def __init__(
self,
feature_information: tuple,
num_classes=1,
lss: bool = False,
config: DefaultTabRConfig = DefaultTabRConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
super().__init__(config=config, lss=lss, **kwargs)
self.save_hyperparameters(ignore=["feature_information"])

# lazy import
Expand Down Expand Up @@ -94,7 +95,7 @@ def make_block(prenorm: bool) -> nn.Sequential:
delu = TabR.delu
self.label_encoder = (
nn.Linear(1, d_main)
if num_classes == 1
if num_classes == 1 or lss
else nn.Sequential(
nn.Embedding(num_classes, d_main),
# gives deprecation warning
Expand Down Expand Up @@ -278,10 +279,10 @@ def train_with_candidates(self, *data, targets, candidate_x, candidate_y):
probs = F.softmax(similarities, dim=-1)
probs = self.dropout(probs)

if self.hparams.num_classes > 1: # for classification
if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
else: # for regression
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
else: # for regression or LSS
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float())
if len(context_y_emb.shape) == 4:
context_y_emb = context_y_emb[:, :, 0, :]

Expand Down Expand Up @@ -324,7 +325,7 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y):
candidate_x, candidate_k = self._encode(candidate_x)

x, k = self._encode(x) # encoded x and k
batch_size, d_main = k.shape
_, d_main = k.shape
device = k.device
context_size = self.context_size

Expand All @@ -338,9 +339,8 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y):
# Updating the index is much faster than creating a new one.
self.search_index.reset()
self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
distances: Tensor
context_idx: Tensor
distances, context_idx = self.search_index.search( # type: ignore[code]
_, context_idx = self.search_index.search( # type: ignore[code]
k.to(torch.float32), context_size
)

Expand All @@ -353,10 +353,10 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y):
probs = F.softmax(similarities, dim=-1)
probs = self.dropout(probs)

if self.hparams.num_classes > 1: # for classification
if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
else: # for regression
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float())
if len(context_y_emb.shape) == 4:
context_y_emb = context_y_emb[:, :, 0, :]

Expand Down Expand Up @@ -398,7 +398,7 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y):
candidate_x, candidate_k = self._encode(candidate_x)

x, k = self._encode(x) # encoded x and k
batch_size, d_main = k.shape
_, d_main = k.shape
device = k.device
context_size = self.context_size

Expand All @@ -412,9 +412,8 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y):
# Updating the index is much faster than creating a new one.
self.search_index.reset()
self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
distances: Tensor
context_idx: Tensor
distances, context_idx = self.search_index.search( # type: ignore[code]
_, context_idx = self.search_index.search( # type: ignore[code]
k.to(torch.float32), context_size
)

Expand All @@ -427,10 +426,10 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y):
probs = F.softmax(similarities, dim=-1)
probs = self.dropout(probs)

if self.hparams.num_classes > 1: # for classification
if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
else: # for regression
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float())
if len(context_y_emb.shape) == 4:
context_y_emb = context_y_emb[:, :, 0, :]

Expand Down
7 changes: 5 additions & 2 deletions deeptab/base_models/tabtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def __init__(
mlp_input_dim = 0
for feature_name, info in num_feature_info.items():
mlp_input_dim += info["dimension"]
num_input_dim = mlp_input_dim # save before adding d_model
mlp_input_dim += self.hparams.d_model

self.num_norm = nn.LayerNorm(num_input_dim)

self.tabular_head = MLPhead(
input_dim=mlp_input_dim,
config=config,
Expand Down Expand Up @@ -125,13 +128,13 @@ def forward(self, *data):
cat_embeddings = self.embedding_layer(*(None, cat_features, emb_features))

num_features = torch.cat(num_features, dim=1)
num_embeddings = self.norm_f(num_features) # type: ignore
num_features = self.num_norm(num_features)

x = self.encoder(cat_embeddings)

x = self.pool_sequence(x)

x = torch.cat((x, num_embeddings), axis=1) # type: ignore
x = torch.cat((x, num_features), axis=1) # type: ignore
preds = self.tabular_head(x)

return preds
2 changes: 1 addition & 1 deletion deeptab/base_models/utils/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def pool_sequence(self, out):
return out[:, 0, :]
elif self.hparams.pooling_method == "learned_flatten":
# Flatten sequence and apply a learned linear layer
batch_size, seq_len, hidden_size = out.shape
batch_size, _, _ = out.shape
# Shape: (batch_size, seq_len * hidden_size)
out = out.reshape(batch_size, -1)
# Shape: (batch_size, hidden_size)
Expand Down
7 changes: 4 additions & 3 deletions deeptab/base_models/utils/lightning_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
else:
self.loss_fct = nn.MSELoss()

self.save_hyperparameters(ignore=["model_class", "loss_fn"])
self.save_hyperparameters(ignore=["model_class", "loss_fn", "family"])

self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
Expand All @@ -93,6 +93,7 @@ def __init__(
config=config,
feature_information=feature_information,
num_classes=output_dim,
lss=lss,
**kwargs,
)

Expand Down Expand Up @@ -176,7 +177,7 @@ def compute_loss(self, predictions, y_true):
if getattr(self.estimator, "returns_ensemble", False): # Ensemble case
if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3:
# Classification case with ensemble: predictions (N, E, k), y_true (N,)
N, E, k = predictions.shape
_, E, _ = predictions.shape
loss = 0.0
for ensemble_member in range(E):
loss += self.loss_fct(
Expand Down Expand Up @@ -595,7 +596,7 @@ def contrastive_loss(self, embeddings, knn_indices, temperature=0.1):
Tensor
Contrastive loss value.
"""
N, S, D = embeddings.shape # Batch size, sequence length, embedding dim
_, S, D = embeddings.shape # Batch size, sequence length, embedding dim
k_neighbors = knn_indices.shape[1] # Number of neighbors

# Normalize embeddings
Expand Down
2 changes: 2 additions & 0 deletions deeptab/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .resnet_config import DefaultResNetConfig
from .saint_config import DefaultSAINTConfig
from .tabm_config import DefaultTabMConfig
from .tabr_config import DefaultTabRConfig
from .tabtransformer_config import DefaultTabTransformerConfig
from .tabularnn_config import DefaultTabulaRNNConfig
from .tangos_config import DefaultTangosConfig
Expand All @@ -32,6 +33,7 @@
"DefaultResNetConfig",
"DefaultSAINTConfig",
"DefaultTabMConfig",
"DefaultTabRConfig",
"DefaultTabTransformerConfig",
"DefaultTabulaRNNConfig",
"DefaultTangosConfig",
Expand Down
48 changes: 36 additions & 12 deletions deeptab/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import importlib
import warnings

from .autoint import AutoIntClassifier, AutoIntLSS, AutoIntRegressor
from .enode import ENODELSS, ENODEClassifier, ENODERegressor
from .fttransformer import (
Expand All @@ -13,20 +16,18 @@
)
from .mambular import MambularClassifier, MambularLSS, MambularRegressor
from .mlp import MLPLSS, MLPClassifier, MLPRegressor
from .modern_nca import ModernNCAClassifier, ModernNCALSS, ModernNCARegressor
from .ndtf import NDTFLSS, NDTFClassifier, NDTFRegressor
from .node import NODELSS, NODEClassifier, NODERegressor
from .resnet import ResNetClassifier, ResNetLSS, ResNetRegressor
from .saint import SAINTLSS, SAINTClassifier, SAINTRegressor
from .tabm import TabMClassifier, TabMLSS, TabMRegressor
from .tabr import TabRClassifier, TabRLSS, TabRRegressor
from .tabtransformer import (
TabTransformerClassifier,
TabTransformerLSS,
TabTransformerRegressor,
)
from .tabularnn import TabulaRNNClassifier, TabulaRNNLSS, TabulaRNNRegressor
from .tangos import TangosClassifier, TangosLSS, TangosRegressor
from .trompt import TromptClassifier, TromptLSS, TromptRegressor
from .utils.sklearn_base_classifier import SklearnBaseClassifier
from .utils.sklearn_base_lss import SklearnBaseLSS
from .utils.sklearn_base_regressor import SklearnBaseRegressor
Expand Down Expand Up @@ -56,9 +57,6 @@
"MambularClassifier",
"MambularLSS",
"MambularRegressor",
"ModernNCAClassifier",
"ModernNCALSS",
"ModernNCARegressor",
"NDTFClassifier",
"NDTFRegressor",
"NODEClassifier",
Expand All @@ -74,16 +72,42 @@
"TabMClassifier",
"TabMLSS",
"TabMRegressor",
"TabRClassifier",
"TabRLSS",
"TabRRegressor",
"TabTransformerClassifier",
"TabTransformerLSS",
"TabTransformerRegressor",
"TabulaRNNClassifier",
"TabulaRNNLSS",
"TabulaRNNRegressor",
"TangosClassifier",
"TangosLSS",
"TangosRegressor",
"TromptClassifier",
"TromptLSS",
"TromptRegressor",
]

# ---------------------------------------------------------------------------
# Backwards-compatibility shim for experimental models
# ---------------------------------------------------------------------------

# Public names that used to be importable from this module, mapped to the
# module path they are now loaded from. Every entry shares the same target.
_EXPERIMENTAL_COMPAT: dict[str, str] = dict.fromkeys(
    (
        "ModernNCAClassifier",
        "ModernNCALSS",
        "ModernNCARegressor",
        "TangosClassifier",
        "TangosLSS",
        "TangosRegressor",
        "TromptClassifier",
        "TromptLSS",
        "TromptRegressor",
    ),
    "deeptab.models.experimental",
)


def __getattr__(name: str):
    """Lazily resolve relocated model names (module-level attribute hook).

    Parameters
    ----------
    name : str
        Attribute that was requested on this module.

    Returns
    -------
    object
        The attribute called ``name`` fetched from the module path recorded
        in ``_EXPERIMENTAL_COMPAT``.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the relocated names.

    Warns
    -----
    DeprecationWarning
        Emitted before the target module is imported, naming the new path.
    """
    target = _EXPERIMENTAL_COMPAT.get(name)
    if target is None:
        raise AttributeError(f"module 'deeptab.models' has no attribute {name!r}")

    # stacklevel=2 attributes the warning to the code that accessed the name
    # rather than to this hook.
    warnings.warn(
        f"{name!r} has moved to '{target}'. Update your import: from {target} import {name}",
        DeprecationWarning,
        stacklevel=2,
    )
    return getattr(importlib.import_module(target), name)
Loading
Loading