From efe643d93be218a75c56f9cddf1237a733a3aba0 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 14:48:56 +0200 Subject: [PATCH 01/22] feat: stable and experimental model separation --- deeptab/models/__init__.py | 16 +- deeptab/models/_registry.py | 33 ++++ deeptab/models/experimental/__init__.py | 23 +++ .../models/{ => experimental}/modern_nca.py | 12 +- deeptab/models/{ => experimental}/tangos.py | 12 +- deeptab/models/{ => experimental}/trompt.py | 12 +- tests/test_model_exports.py | 160 ++++++++++++++++++ 7 files changed, 238 insertions(+), 30 deletions(-) create mode 100644 deeptab/models/_registry.py create mode 100644 deeptab/models/experimental/__init__.py rename deeptab/models/{ => experimental}/modern_nca.py (86%) rename deeptab/models/{ => experimental}/tangos.py (86%) rename deeptab/models/{ => experimental}/trompt.py (85%) create mode 100644 tests/test_model_exports.py diff --git a/deeptab/models/__init__.py b/deeptab/models/__init__.py index ebffddc..b23d90b 100644 --- a/deeptab/models/__init__.py +++ b/deeptab/models/__init__.py @@ -13,20 +13,18 @@ ) from .mambular import MambularClassifier, MambularLSS, MambularRegressor from .mlp import MLPLSS, MLPClassifier, MLPRegressor -from .modern_nca import ModernNCAClassifier, ModernNCALSS, ModernNCARegressor from .ndtf import NDTFLSS, NDTFClassifier, NDTFRegressor from .node import NODELSS, NODEClassifier, NODERegressor from .resnet import ResNetClassifier, ResNetLSS, ResNetRegressor from .saint import SAINTLSS, SAINTClassifier, SAINTRegressor from .tabm import TabMClassifier, TabMLSS, TabMRegressor +from .tabr import TabRClassifier, TabRLSS, TabRRegressor from .tabtransformer import ( TabTransformerClassifier, TabTransformerLSS, TabTransformerRegressor, ) from .tabularnn import TabulaRNNClassifier, TabulaRNNLSS, TabulaRNNRegressor -from .tangos import TangosClassifier, TangosLSS, TangosRegressor -from .trompt import TromptClassifier, TromptLSS, TromptRegressor from .utils.sklearn_base_classifier import SklearnBaseClassifier from .utils.sklearn_base_lss import SklearnBaseLSS from .utils.sklearn_base_regressor import SklearnBaseRegressor @@ -56,9 +54,6 @@ "MambularClassifier", "MambularLSS", "MambularRegressor", - "ModernNCAClassifier", - "ModernNCALSS", - "ModernNCARegressor", "NDTFClassifier", "NDTFRegressor", "NODEClassifier", @@ -74,16 +69,13 @@ "TabMClassifier", "TabMLSS", "TabMRegressor", + "TabRClassifier", + "TabRLSS", + "TabRRegressor", "TabTransformerClassifier", "TabTransformerLSS", "TabTransformerRegressor", "TabulaRNNClassifier", "TabulaRNNLSS", "TabulaRNNRegressor", - "TangosClassifier", - "TangosLSS", - "TangosRegressor", - "TromptClassifier", - "TromptLSS", - "TromptRegressor", ] diff --git a/deeptab/models/_registry.py b/deeptab/models/_registry.py new file mode 100644 index 0000000..df2703c --- /dev/null +++ b/deeptab/models/_registry.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import Literal + +ModelStatus = Literal["stable", "experimental"] + + +@dataclass(frozen=True) +class ModelInfo: + name: str + status: ModelStatus + import_path: str + + +MODEL_REGISTRY: dict[str, ModelInfo] = { + "Mambular": ModelInfo("Mambular", "stable", "deeptab.models"), + "TabM": ModelInfo("TabM", "stable", "deeptab.models"), + "NODE": ModelInfo("NODE", "stable", "deeptab.models"), + "ENODE": ModelInfo("ENODE", "stable", "deeptab.models"), + "FTTransformer": ModelInfo("FTTransformer", "stable", "deeptab.models"), + "MLP": ModelInfo("MLP", "stable", "deeptab.models"), + "ResNet": ModelInfo("ResNet", "stable", "deeptab.models"), + "TabTransformer": ModelInfo("TabTransformer", "stable", "deeptab.models"), + "MambaTab": ModelInfo("MambaTab", "stable", "deeptab.models"), + "TabulaRNN": ModelInfo("TabulaRNN", "stable", "deeptab.models"), + "MambAttention": ModelInfo("MambAttention", "stable", "deeptab.models"), + "NDTF": ModelInfo("NDTF", "stable", "deeptab.models"), + "SAINT": ModelInfo("SAINT", "stable", "deeptab.models"), + "AutoInt": ModelInfo("AutoInt", "stable", "deeptab.models"), + "TabR": ModelInfo("TabR", "stable", "deeptab.models"), + "ModernNCA": ModelInfo("ModernNCA", "experimental", "deeptab.models.experimental"), + "Tangos": ModelInfo("Tangos", "experimental", "deeptab.models.experimental"), + "Trompt": ModelInfo("Trompt", "experimental", "deeptab.models.experimental"), +} diff --git a/deeptab/models/experimental/__init__.py b/deeptab/models/experimental/__init__.py new file mode 100644 index 0000000..59ad46b --- /dev/null +++ b/deeptab/models/experimental/__init__.py @@ -0,0 +1,23 @@ +""" +Experimental models — subject to change without notice. + +Import these explicitly to signal that you accept the instability: + + from deeptab.models.experimental import ModernNCA, Tangos, Trompt +""" + +from .modern_nca import ModernNCAClassifier, ModernNCALSS, ModernNCARegressor +from .tangos import TangosClassifier, TangosLSS, TangosRegressor +from .trompt import TromptClassifier, TromptLSS, TromptRegressor + +__all__ = [ + "ModernNCAClassifier", + "ModernNCALSS", + "ModernNCARegressor", + "TangosClassifier", + "TangosLSS", + "TangosRegressor", + "TromptClassifier", + "TromptLSS", + "TromptRegressor", +] diff --git a/deeptab/models/modern_nca.py b/deeptab/models/experimental/modern_nca.py similarity index 86% rename from deeptab/models/modern_nca.py rename to deeptab/models/experimental/modern_nca.py index 4b78479..1993848 100644 --- a/deeptab/models/modern_nca.py +++ b/deeptab/models/experimental/modern_nca.py @@ -1,9 +1,9 @@ -from ..base_models.modern_nca import ModernNCA -from ..configs.modernnca_config import DefaultModernNCAConfig -from ..utils.docstring_generator import generate_docstring -from .utils.sklearn_base_classifier import SklearnBaseClassifier -from .utils.sklearn_base_lss import SklearnBaseLSS -from .utils.sklearn_base_regressor import SklearnBaseRegressor +from ...base_models.modern_nca import ModernNCA +from ...configs.modernnca_config import DefaultModernNCAConfig +from ...utils.docstring_generator import generate_docstring +from ..utils.sklearn_base_classifier import SklearnBaseClassifier +from ..utils.sklearn_base_lss import SklearnBaseLSS +from ..utils.sklearn_base_regressor import SklearnBaseRegressor class ModernNCARegressor(SklearnBaseRegressor): diff --git a/deeptab/models/tangos.py b/deeptab/models/experimental/tangos.py similarity index 86% rename from deeptab/models/tangos.py rename to deeptab/models/experimental/tangos.py index abbd437..2b5d6c6 100644 --- a/deeptab/models/tangos.py +++ b/deeptab/models/experimental/tangos.py @@ -1,9 +1,9 @@ -from ..base_models.tangos import Tangos -from ..configs.tangos_config import DefaultTangosConfig -from ..utils.docstring_generator import generate_docstring -from .utils.sklearn_base_classifier import SklearnBaseClassifier -from .utils.sklearn_base_lss import SklearnBaseLSS -from .utils.sklearn_base_regressor import SklearnBaseRegressor +from ...base_models.tangos import Tangos +from ...configs.tangos_config import DefaultTangosConfig +from ...utils.docstring_generator import generate_docstring +from ..utils.sklearn_base_classifier import SklearnBaseClassifier +from ..utils.sklearn_base_lss import SklearnBaseLSS +from ..utils.sklearn_base_regressor import SklearnBaseRegressor class TangosRegressor(SklearnBaseRegressor): diff --git a/deeptab/models/trompt.py b/deeptab/models/experimental/trompt.py similarity index 85% rename from deeptab/models/trompt.py rename to deeptab/models/experimental/trompt.py index d827a99..40b4360 100644 --- a/deeptab/models/trompt.py +++ b/deeptab/models/experimental/trompt.py @@ -1,9 +1,9 @@ -from ..base_models.trompt import Trompt -from ..configs.trompt_config import DefaultTromptConfig -from ..utils.docstring_generator import generate_docstring -from .utils.sklearn_base_classifier import SklearnBaseClassifier -from .utils.sklearn_base_lss import SklearnBaseLSS -from .utils.sklearn_base_regressor import SklearnBaseRegressor +from ...base_models.trompt import Trompt +from ...configs.trompt_config import DefaultTromptConfig +from ...utils.docstring_generator import generate_docstring +from ..utils.sklearn_base_classifier import SklearnBaseClassifier +from ..utils.sklearn_base_lss import SklearnBaseLSS +from ..utils.sklearn_base_regressor import SklearnBaseRegressor class TromptRegressor(SklearnBaseRegressor): diff --git a/tests/test_model_exports.py b/tests/test_model_exports.py new file mode 100644 index 0000000..f245cfc --- /dev/null +++ b/tests/test_model_exports.py @@ -0,0 +1,160 @@ +""" +Tests verifying the stable / experimental model export split. + +Stable models must be importable from ``deeptab.models``. +Experimental models must be importable from ``deeptab.models.experimental`` +and must *not* appear in ``deeptab.models.__all__``. +""" + +import importlib + +import pytest + +# --------------------------------------------------------------------------- +# Expected exports +# --------------------------------------------------------------------------- + +STABLE_CLASSES = [ + "AutoIntClassifier", + "AutoIntLSS", + "AutoIntRegressor", + "ENODEClassifier", + "ENODELSS", + "ENODERegressor", + "FTTransformerClassifier", + "FTTransformerLSS", + "FTTransformerRegressor", + "MambAttentionClassifier", + "MambAttentionLSS", + "MambAttentionRegressor", + "MambaTabClassifier", + "MambaTabLSS", + "MambaTabRegressor", + "MambularClassifier", + "MambularLSS", + "MambularRegressor", + "MLPClassifier", + "MLPLSS", + "MLPRegressor", + "NDTFClassifier", + "NDTFLSS", + "NDTFRegressor", + "NODEClassifier", + "NODELSS", + "NODERegressor", + "ResNetClassifier", + "ResNetLSS", + "ResNetRegressor", + "SAINTClassifier", + "SAINTLSS", + "SAINTRegressor", + "TabMClassifier", + "TabMLSS", + "TabMRegressor", + "TabRClassifier", + "TabRLSS", + "TabRRegressor", + "TabTransformerClassifier", + "TabTransformerLSS", + "TabTransformerRegressor", + "TabulaRNNClassifier", + "TabulaRNNLSS", + "TabulaRNNRegressor", +] + +EXPERIMENTAL_CLASSES = [ + "ModernNCAClassifier", + "ModernNCALSS", + "ModernNCARegressor", + "TangosClassifier", + "TangosLSS", + "TangosRegressor", + "TromptClassifier", + "TromptLSS", + "TromptRegressor", +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _import(module_path: str, attr: str): + mod = importlib.import_module(module_path) + return getattr(mod, attr) + + +# --------------------------------------------------------------------------- +# Stable models +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("class_name", STABLE_CLASSES) +def test_stable_model_importable(class_name: str): + """Every stable class is importable from deeptab.models.""" + cls = _import("deeptab.models", class_name) + assert cls is not None + + +@pytest.mark.parametrize("class_name", STABLE_CLASSES) +def test_stable_model_in_all(class_name: str): + """Every stable class is listed in deeptab.models.__all__.""" + import deeptab.models as m + + assert class_name in m.__all__, f"{class_name!r} missing from deeptab.models.__all__" + + +# --------------------------------------------------------------------------- +# Experimental models +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_model_importable(class_name: str): + """Every experimental class is importable from deeptab.models.experimental.""" + cls = _import("deeptab.models.experimental", class_name) + assert cls is not None + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_model_in_experimental_all(class_name: str): + """Every experimental class is listed in deeptab.models.experimental.__all__.""" + import deeptab.models.experimental as exp + + assert class_name in exp.__all__, f"{class_name!r} missing from deeptab.models.experimental.__all__" + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_model_not_in_stable_all(class_name: str): + """Experimental classes must not leak into deeptab.models.__all__.""" + import deeptab.models as m + + assert class_name not in m.__all__, f"{class_name!r} should not be in deeptab.models.__all__" + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + + +def test_registry_stable_import_paths(): + """All stable entries in MODEL_REGISTRY have import_path == 'deeptab.models'.""" + from deeptab.models._registry import MODEL_REGISTRY + + for name, info in MODEL_REGISTRY.items(): + if info.status == "stable": + assert info.import_path == "deeptab.models", ( + f"{name}: expected import_path 'deeptab.models', got {info.import_path!r}" + ) + + +def test_registry_experimental_import_paths(): + """All experimental entries have import_path == 'deeptab.models.experimental'.""" + from deeptab.models._registry import MODEL_REGISTRY + + for name, info in MODEL_REGISTRY.items(): + if info.status == "experimental": + assert info.import_path == "deeptab.models.experimental", ( + f"{name}: expected 'deeptab.models.experimental', got {info.import_path!r}" + ) From c59273195ff0adf8caa4e95f8a880e00cef7c17c Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 15:58:08 +0200 Subject: [PATCH 02/22] fix: resolve stale site-packages --- justfile | 3 ++- pyproject.toml | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/justfile b/justfile index 65938f3..e3e5aeb 100644 --- a/justfile +++ b/justfile @@ -2,9 +2,10 @@ default: @just --list --unsorted -# install dependencies and set up all pre-commit hooks +# install dependencies, editable package, and set up all pre-commit hooks install: poetry install + poetry run pip install -e . --quiet poetry run pre-commit install --hook-type commit-msg --hook-type pre-commit --hook-type pre-push # update dependencies and pre-commit hook revisions diff --git a/pyproject.toml b/pyproject.toml index 0f33720..67c8a2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,10 @@ repository = "https://github.com/OpenTabular/deeptab" package = "https://pypi.org/project/deeptab/" +# test configuration +[tool.pytest.ini_options] +pythonpath = ["."] + # code quality tools [tool.pyright] include = ["deeptab", "tests"] From ad23aafd879039181013ddccd6e2ded69ef7873c Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 15:59:33 +0200 Subject: [PATCH 03/22] fix: import error after experimental namespace changes --- deeptab/models/__init__.py | 32 +++++++++++++++++++ deeptab/models/experimental/modern_nca.py | 6 ++-- deeptab/models/experimental/tangos.py | 6 ++-- deeptab/models/experimental/trompt.py | 6 ++-- tests/test_model_exports.py | 39 +++++++++++++++++++++++ 5 files changed, 80 insertions(+), 9 deletions(-) diff --git a/deeptab/models/__init__.py b/deeptab/models/__init__.py index b23d90b..48838d2 100644 --- a/deeptab/models/__init__.py +++ b/deeptab/models/__init__.py @@ -1,3 +1,6 @@ +import importlib +import warnings + from .autoint import AutoIntClassifier, AutoIntLSS, AutoIntRegressor from .enode import ENODELSS, ENODEClassifier, ENODERegressor from .fttransformer import ( @@ -79,3 +82,32 @@ "TabulaRNNLSS", "TabulaRNNRegressor", ] + +# --------------------------------------------------------------------------- +# Backwards-compatibility shim for experimental models +# --------------------------------------------------------------------------- + +_EXPERIMENTAL_COMPAT: dict[str, str] = { + "ModernNCAClassifier": "deeptab.models.experimental", + "ModernNCALSS": "deeptab.models.experimental", + "ModernNCARegressor": "deeptab.models.experimental", + "TangosClassifier": "deeptab.models.experimental", + "TangosLSS": "deeptab.models.experimental", + "TangosRegressor": "deeptab.models.experimental", + "TromptClassifier": "deeptab.models.experimental", + "TromptLSS": "deeptab.models.experimental", + "TromptRegressor": "deeptab.models.experimental", +} + + +def __getattr__(name: str): + if name in _EXPERIMENTAL_COMPAT: + new_path = _EXPERIMENTAL_COMPAT[name] + warnings.warn( + f"{name!r} has moved to '{new_path}'. Update your import: from {new_path} import {name}", + DeprecationWarning, + stacklevel=2, + ) + mod = importlib.import_module(new_path) + return getattr(mod, name) + raise AttributeError(f"module 'deeptab.models' has no attribute {name!r}") diff --git a/deeptab/models/experimental/modern_nca.py b/deeptab/models/experimental/modern_nca.py index 1993848..6530e18 100644 --- a/deeptab/models/experimental/modern_nca.py +++ b/deeptab/models/experimental/modern_nca.py @@ -14,7 +14,7 @@ class ModernNCARegressor(SklearnBaseRegressor): with the default ModernNCA configuration. """, examples=""" - >>> from deeptab.models import ModernNCARegressor + >>> from deeptab.models.experimental import ModernNCARegressor >>> model = ModernNCARegressor(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -34,7 +34,7 @@ class ModernNCAClassifier(SklearnBaseClassifier): with the default ModernNCA configuration. """, examples=""" - >>> from deeptab.models import ModernNCAClassifier + >>> from deeptab.models.experimental import ModernNCAClassifier >>> model = ModernNCAClassifier(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -54,7 +54,7 @@ class ModernNCALSS(SklearnBaseLSS): with the default ModernNCA configuration. """, examples=""" - >>> from deeptab.models import ModernNCALSS + >>> from deeptab.models.experimental import ModernNCALSS >>> model = ModernNCALSS(d_model=64, n_layers=8) >>> model.fit(X_train, y_train, family='normal') >>> preds = model.predict(X_test) diff --git a/deeptab/models/experimental/tangos.py b/deeptab/models/experimental/tangos.py index 2b5d6c6..9502791 100644 --- a/deeptab/models/experimental/tangos.py +++ b/deeptab/models/experimental/tangos.py @@ -14,7 +14,7 @@ class TangosRegressor(SklearnBaseRegressor): with the default Tangos configuration. """, examples=""" - >>> from deeptab.models import TangosRegressor + >>> from deeptab.models.experimental import TangosRegressor >>> model = TangosRegressor(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -34,7 +34,7 @@ class TangosClassifier(SklearnBaseClassifier): with the default Tangos configuration. """, examples=""" - >>> from deeptab.models import TangosClassifier + >>> from deeptab.models.experimental import TangosClassifier >>> model = TangosClassifier(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -54,7 +54,7 @@ class TangosLSS(SklearnBaseLSS): with the default Tangos configuration. """, examples=""" - >>> from deeptab.models import TangosLSS + >>> from deeptab.models.experimental import TangosLSS >>> model = TangosLSS(d_model=64, n_layers=8) >>> model.fit(X_train, y_train, family='normal') >>> preds = model.predict(X_test) diff --git a/deeptab/models/experimental/trompt.py b/deeptab/models/experimental/trompt.py index 40b4360..3109cae 100644 --- a/deeptab/models/experimental/trompt.py +++ b/deeptab/models/experimental/trompt.py @@ -15,7 +15,7 @@ class and uses the Trompt model with the default Trompt configuration. """, examples=""" - >>> from deeptab.models import TromptRegressor + >>> from deeptab.models.experimental import TromptRegressor >>> model = TromptRegressor(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -33,7 +33,7 @@ class TromptClassifier(SklearnBaseClassifier): """Trompt Classifier. This class extends the SklearnBaseClassifier class and uses the Trompt model with the default Trompt configuration.""", examples=""" - >>> from deeptab.models import TromptClassifier + >>> from deeptab.models.experimental import TromptClassifier >>> model = TromptClassifier(d_model=64, n_layers=8) >>> model.fit(X_train, y_train) >>> preds = model.predict(X_test) @@ -52,7 +52,7 @@ class TromptLSS(SklearnBaseLSS): This class extends the SklearnBaseLSS class and uses the Trompt model with the default Trompt configuration.""", examples=""" - >>> from deeptab.models import TromptLSS + >>> from deeptab.models.experimental import TromptLSS >>> model = TromptLSS(d_model=64, n_layers=8) >>> model.fit(X_train, y_train, family="normal") >>> preds = model.predict(X_test) diff --git a/tests/test_model_exports.py b/tests/test_model_exports.py index f245cfc..4d1e58f 100644 --- a/tests/test_model_exports.py +++ b/tests/test_model_exports.py @@ -158,3 +158,42 @@ def test_registry_experimental_import_paths(): assert info.import_path == "deeptab.models.experimental", ( f"{name}: expected 'deeptab.models.experimental', got {info.import_path!r}" ) + + +# --------------------------------------------------------------------------- +# Deprecation warnings +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_experimental_via_stable_emits_deprecation(class_name: str): + """Accessing an experimental class via deeptab.models emits DeprecationWarning.""" + import deeptab.models as m + + with pytest.warns(DeprecationWarning, match=class_name): + cls = getattr(m, class_name) + assert cls is not None + + +@pytest.mark.parametrize("class_name", EXPERIMENTAL_CLASSES) +def test_deprecation_message_contains_new_path(class_name: str): + """Deprecation warning message includes the correct new import path.""" + import warnings + + import deeptab.models as m + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + getattr(m, class_name) + + assert any("deeptab.models.experimental" in str(w.message) for w in caught), ( + f"Expected new path in deprecation message for {class_name!r}" + ) + + +def test_unknown_attribute_raises(): + """Accessing a truly unknown attribute on deeptab.models still raises AttributeError.""" + import deeptab.models as m + + with pytest.raises(AttributeError): + _ = m.ThisDoesNotExist From a993829d907b69742566c30f077e389ef794c76b Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 16:03:14 +0200 Subject: [PATCH 04/22] docs: docstring added to the models ndtf, tabularnn and mambatab --- deeptab/models/mambatab.py | 24 ++++++++-- deeptab/models/ndtf.py | 28 +++++++++-- deeptab/models/tabularnn.py | 94 ++++++------------------------------- 3 files changed, 59 insertions(+), 87 deletions(-) diff --git a/deeptab/models/mambatab.py b/deeptab/models/mambatab.py index d6f98ba..7885a08 100644 --- a/deeptab/models/mambatab.py +++ b/deeptab/models/mambatab.py @@ -13,7 +13,13 @@ class MambaTabRegressor(SklearnBaseRegressor): MambaTab regressor. This class extends the SklearnBaseRegressor class and uses the MambaTab model with the default MambaTab configuration. """, - examples="", + examples=""" + >>> from deeptab.models import MambaTabRegressor + >>> model = MambaTabRegressor(d_model=64, n_layers=2) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -27,7 +33,13 @@ class MambaTabClassifier(SklearnBaseClassifier): MambaTab classifier. This class extends the SklearnBaseClassifier class and uses the MambaTab model with the default MambaTab configuration. """, - examples="", + examples=""" + >>> from deeptab.models import MambaTabClassifier + >>> model = MambaTabClassifier(d_model=64, n_layers=2) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -41,7 +53,13 @@ class MambaTabLSS(SklearnBaseLSS): MambaTab LSS for distributional regression. This class extends the SklearnBaseLSS class and uses the MambaTab model with the default MambaTab configuration. """, - examples="", + examples=""" + >>> from deeptab.models import MambaTabLSS + >>> model = MambaTabLSS(d_model=64, n_layers=2) + >>> model.fit(X_train, y_train, family='normal') + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): diff --git a/deeptab/models/ndtf.py b/deeptab/models/ndtf.py index dc02d71..ca7b552 100644 --- a/deeptab/models/ndtf.py +++ b/deeptab/models/ndtf.py @@ -13,7 +13,13 @@ class NDTFRegressor(SklearnBaseRegressor): Neural Decision Forest regressor. This class extends the SklearnBaseRegressor class and uses the NDTF model with the default NDTF configuration. """, - examples="", + examples=""" + >>> from deeptab.models import NDTFRegressor + >>> model = NDTFRegressor(n_ensembles=12, max_depth=8) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -24,10 +30,16 @@ class NDTFClassifier(SklearnBaseClassifier): __doc__ = generate_docstring( DefaultNDTFConfig, model_description=""" - Neural Decision Forest classifier. This class extends the SklearnBasClassifier class and uses the NDTF model + Neural Decision Forest classifier. This class extends the SklearnBaseClassifier class and uses the NDTF model with the default NDTF configuration. """, - examples="", + examples=""" + >>> from deeptab.models import NDTFClassifier + >>> model = NDTFClassifier(n_ensembles=12, max_depth=8) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): @@ -38,10 +50,16 @@ class NDTFLSS(SklearnBaseLSS): __doc__ = generate_docstring( DefaultNDTFConfig, model_description=""" - Neural Decision Forest for distributional regressor. This class extends the SklearnBaseLSS class and uses the NDTF model + Neural Decision Forest for distributional regression. This class extends the SklearnBaseLSS class and uses the NDTF model with the default NDTF configuration. """, - examples="", + examples=""" + >>> from deeptab.models import NDTFLSS + >>> model = NDTFLSS(n_ensembles=12, max_depth=8) + >>> model.fit(X_train, y_train, family='normal') + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, ) def __init__(self, **kwargs): diff --git a/deeptab/models/tabularnn.py b/deeptab/models/tabularnn.py index a14910a..8febf9e 100644 --- a/deeptab/models/tabularnn.py +++ b/deeptab/models/tabularnn.py @@ -49,85 +49,21 @@ def __init__(self, **kwargs): class TabulaRNNLSS(SklearnBaseLSS): - """RNN LSS. This class extends the SklearnBaseLSS class and uses the TabulaRNN model with the default TabulaRNN - configuration. - - The accepted arguments to the TabulaRNNLSS class include both the attributes in the DefaultTabulaRNNConfig dataclass - and the parameters for the Preprocessor class. - - Parameters - ---------- - lr : float, default=1e-04 - Learning rate for the optimizer. - model_type : str, default="RNN" - type of model, one of "RNN", "LSTM", "GRU" - family : str, default=None - Distributional family to be used for the model. - lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. - weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. - lr_factor : float, default=0.1 - Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the transformer. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() - Activation function for the transformer. - embedding_activation : callable, default=nn.Identity() - Activation function for numerical embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. - head_dropout : float, default=0.5 - Dropout rate for the head layers. - head_skip_layers : bool, default=False - Whether to skip layers in the head. - head_activation : callable, default=nn.SELU() - Activation function for the head layers. - head_use_batch_norm : bool, default=False - Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="cls" - Pooling method to be used ('cls', 'avg', etc.). - norm_first : bool, default=False - Whether to apply normalization before other operations in each transformer block. - bias : bool, default=True - Whether to use bias in the linear layers. - rnn_activation : callable, default=nn.SELU() - Activation function for the transformer layers. - bidirectional : bool, default=False. - Whether to process data bidirectionally - cat_encoding : str, default="int" - Encoding method for categorical features. - n_bins : int, default=50 - The number of bins to use for numerical feature binning. This parameter is relevant - only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. - numerical_preprocessing : str, default="ple" - The preprocessing strategy for numerical features. Valid options are - 'binning', 'one_hot', 'standardization', and 'normalization'. - use_decision_tree_bins : bool, default=False - If True, uses decision tree regression/classification to determine - optimal bin edges for numerical feature binning. This parameter is - relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. - binning_strategy : str, default="uniform" - Defines the strategy for binning numerical features. Options include 'uniform', - 'quantile', or other sklearn-compatible strategies. - cat_cutoff : float or int, default=0.03 - Indicates the cutoff after which integer values are treated as categorical. - If float, it's treated as a percentage. If int, it's the maximum number of - unique values for a column to be considered categorical. - treat_all_integers_as_numerical : bool, default=False - If True, all integer columns will be treated as numerical, regardless - of their unique value count or proportion. - degree : int, default=3 - The degree of the polynomial features to be used in preprocessing. - knots : int, default=12 - The number of knots to be used in spline transformations. - """ + __doc__ = generate_docstring( + DefaultTabulaRNNConfig, + model_description=""" + TabulaRNN for distributional regression. This class extends the SklearnBaseLSS + class and uses the TabulaRNN model with the default TabulaRNN configuration. + Supports RNN, LSTM, GRU, mLSTM, and sLSTM architectures. + """, + examples=""" + >>> from deeptab.models import TabulaRNNLSS + >>> model = TabulaRNNLSS(model_type='LSTM', d_model=128, n_layers=4) + >>> model.fit(X_train, y_train, family='normal') + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """, + ) def __init__(self, **kwargs): super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs) From 82ba57a5aa1418b3e1f7354212c4f7ce609035a7 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 16:12:04 +0200 Subject: [PATCH 05/22] docs: update docs for stable/experimental model split --- docs/api/models/Models.rst | 29 +++-- docs/api/models/index.rst | 24 ++++ .../developer_guide/model_promotion_policy.md | 14 +- docs/examples/experimental.md | 121 ++++++++++++++++++ docs/index.rst | 1 + docs/key_concepts.md | 21 +++ docs/overview.md | 15 ++- 7 files changed, 208 insertions(+), 17 deletions(-) create mode 100644 docs/examples/experimental.md diff --git a/docs/api/models/Models.rst b/docs/api/models/Models.rst index fd02d9f..7b64f5e 100644 --- a/docs/api/models/Models.rst +++ b/docs/api/models/Models.rst @@ -169,39 +169,50 @@ deeptab.models :members: :undoc-members: -.. autoclass:: deeptab.models.ModernNCAClassifier + +Experimental Models +------------------- + +.. warning:: + + The classes below live in ``deeptab.models.experimental``. Their API may + change without a deprecation cycle. Import them explicitly:: + + from deeptab.models.experimental import ModernNCAClassifier + +.. autoclass:: deeptab.models.experimental.ModernNCAClassifier :members: :undoc-members: -.. autoclass:: deeptab.models.ModernNCARegressor +.. autoclass:: deeptab.models.experimental.ModernNCARegressor :members: :undoc-members: -.. autoclass:: deeptab.models.ModernNCALSS +.. autoclass:: deeptab.models.experimental.ModernNCALSS :members: :undoc-members: -.. autoclass:: deeptab.models.TangosClassifier +.. autoclass:: deeptab.models.experimental.TangosClassifier :members: :undoc-members: -.. autoclass:: deeptab.models.TangosRegressor +.. autoclass:: deeptab.models.experimental.TangosRegressor :members: :undoc-members: -.. autoclass:: deeptab.models.TangosLSS +.. autoclass:: deeptab.models.experimental.TangosLSS :members: :undoc-members: -.. autoclass:: deeptab.models.TromptClassifier +.. autoclass:: deeptab.models.experimental.TromptClassifier :members: :undoc-members: -.. autoclass:: deeptab.models.TromptRegressor +.. autoclass:: deeptab.models.experimental.TromptRegressor :members: :undoc-members: -.. autoclass:: deeptab.models.TromptLSS +.. autoclass:: deeptab.models.experimental.TromptLSS :members: :undoc-members: diff --git a/docs/api/models/index.rst b/docs/api/models/index.rst index 20fe846..84897da 100644 --- a/docs/api/models/index.rst +++ b/docs/api/models/index.rst @@ -137,6 +137,30 @@ Modules Description :class:`SklearnBaseRegressor` Base class for regression tasks. ======================================= ======================================================================================================= +Experimental Models +------------------- + +.. warning:: + + Experimental models are available from ``deeptab.models.experimental``. + Their API may change without a deprecation cycle. + +.. currentmodule:: deeptab.models.experimental + +======================================= =========================================================================== +Modules Description +======================================= =========================================================================== +:class:`ModernNCAClassifier` ModernNCA for classification tasks. +:class:`ModernNCARegressor` ModernNCA for regression tasks. +:class:`ModernNCALSS` ModernNCA for distributional tasks. +:class:`TangosClassifier` Tangos for classification tasks. +:class:`TangosRegressor` Tangos for regression tasks. +:class:`TangosLSS` Tangos for distributional tasks. +:class:`TromptClassifier` Trompt for classification tasks. +:class:`TromptRegressor` Trompt for regression tasks. +:class:`TromptLSS` Trompt for distributional tasks. +======================================= =========================================================================== + .. toctree:: :maxdepth: 1 diff --git a/docs/developer_guide/model_promotion_policy.md b/docs/developer_guide/model_promotion_policy.md index 2f70f70..7690110 100644 --- a/docs/developer_guide/model_promotion_policy.md +++ b/docs/developer_guide/model_promotion_policy.md @@ -54,16 +54,20 @@ No open GitHub issues labelled `bug` for the model may describe a failure in a c ### 7. Registry -A config class must exist in `deeptab/configs/` and be exported from `deeptab/configs/__init__.py`. The model must be exported from `deeptab/models/__init__.py` and listed in `deeptab/utils/config_mapper.py`. +A config class must exist in `deeptab/configs/` and be exported from `deeptab/configs/__init__.py`. The model must be exported from `deeptab/models/experimental/__init__.py` while experimental, or from `deeptab/models/__init__.py` once stable, and listed in `deeptab/utils/config_mapper.py`. The `MODEL_REGISTRY` in `deeptab/models/_registry.py` must contain an entry with the correct `status` and `import_path`. ## Promotion PR Open a PR titled `feat(): promote to stable`. The PR must: -1. Remove any `.. experimental::` admonition from the model's doc page. -2. Remove any `ExperimentalWarning` raised in `__init__` or `fit`. -3. Remove the experimental badge from the API reference entry. -4. Add the model to the changelog under `### Promoted to Stable`. +1. Move the model file from `deeptab/models/experimental/` to `deeptab/models/` using `git mv`. +2. Update relative imports in the moved file (reduce one `..` level). +3. Remove the model from `deeptab/models/experimental/__init__.py` and its `__all__`. +4. Add the model to `deeptab/models/__init__.py` imports and `__all__`. +5. Update `MODEL_REGISTRY` in `deeptab/models/_registry.py`: change `status` to `"stable"` and `import_path` to `"deeptab.models"`. +6. Remove any `.. experimental::` admonition from the model's doc page. +7. Remove the experimental badge from the API reference entry. +8. Add the model to the changelog under `### Promoted to Stable`. Approval requires at least one maintainer review beyond the author. Use the promotion checklist in the PR template to track each requirement. diff --git a/docs/examples/experimental.md b/docs/examples/experimental.md new file mode 100644 index 0000000..0b0c8c7 --- /dev/null +++ b/docs/examples/experimental.md @@ -0,0 +1,121 @@ +# Using Experimental Models + +Experimental models live in `deeptab.models.experimental`. Their API may change +without a deprecation cycle, but they are otherwise fully functional and follow +the same `fit` / `predict` / `evaluate` interface as stable models. + +```{warning} +Experimental models are not covered by semantic versioning guarantees. +Pin your DeepTab version (`deeptab==x.y.z`) if you use them in production code +to avoid unexpected breakage after upgrades. +``` + +## Import path + +```python +# stable models — imported directly from deeptab.models +from deeptab.models import MambularClassifier + +# experimental models — always import from deeptab.models.experimental +from deeptab.models.experimental import TromptClassifier, ModernNCARegressor, TangosLSS +``` + +Importing an experimental class directly from `deeptab.models` (the old path) +still works but raises a `DeprecationWarning`: + +```python +# raises DeprecationWarning — update the import +from deeptab.models import TromptClassifier +``` + +--- + +## End-to-end example — Trompt for classification + +### Setup + +```python +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +from deeptab.models.experimental import TromptClassifier +``` + +### Generate data + +```python +np.random.seed(42) + +n_samples, n_features, n_classes = 800, 6, 3 +X = np.random.randn(n_samples, n_features) +y = np.random.randint(0, n_classes, size=n_samples) + +df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(n_features)]) +X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=42) +``` + +### Train + +```python +model = TromptClassifier() +model.fit(X_train, y_train, max_epochs=10) +``` + +### Evaluate + +```python +metrics = model.evaluate(X_test, y_test) +print(metrics) +``` + +### Predict + +```python +preds = model.predict(X_test) +proba = model.predict_proba(X_test) +``` + +--- + +## End-to-end example — ModernNCA for regression + +```python +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +from deeptab.models.experimental import ModernNCARegressor + +np.random.seed(0) +n_samples, n_features = 800, 5 +X = np.random.randn(n_samples, n_features) +y = X @ np.random.randn(n_features) + np.random.randn(n_samples) * 0.1 + +df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(n_features)]) +X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0) + +model = ModernNCARegressor(d_model=64, n_layers=4) +model.fit(X_train, y_train, max_epochs=10) + +metrics = model.evaluate(X_test, y_test) +print(metrics) +``` + +--- + +## Switching between experimental and stable + +The API is identical — only the import path changes. When a model is promoted to +stable, update the import and nothing else: + +```python +# Before promotion +from deeptab.models.experimental import TromptClassifier + +# After promotion (no other code changes needed) +from deeptab.models import TromptClassifier +``` + +See [Model Promotion Policy](../developer_guide/model_promotion_policy) for the +criteria a model must meet before it moves to stable. diff --git a/docs/index.rst b/docs/index.rst index ff6df57..874df7c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,7 @@ examples/classification examples/regression examples/distributional + examples/experimental .. toctree:: :name: API Reference diff --git a/docs/key_concepts.md b/docs/key_concepts.md index 25e498d..51db1e1 100644 --- a/docs/key_concepts.md +++ b/docs/key_concepts.md @@ -33,6 +33,27 @@ from deeptab.models import MambularRegressor # regression from deeptab.models import MambularLSS # distributional regression ``` +## Stable vs experimental models + +DeepTab ships models at two tiers: + +| Tier | Import path | Guarantee | +| ---------------- | --------------------------------------------- | ------------------------------------------- | +| **Stable** | `from deeptab.models import ...` | Public API frozen under semantic versioning | +| **Experimental** | `from deeptab.models.experimental import ...` | May change without a deprecation cycle | + +Always use the explicit experimental import path to signal that you accept the instability: + +```python +# stable +from deeptab.models import FTTransformerClassifier + +# experimental — explicit path required +from deeptab.models.experimental import TromptClassifier +``` + +See [Using experimental models](examples/experimental) for a full worked example. + ## Configuring hyperparameters Every model has a corresponding config class in `deeptab.configs` that documents all available hyperparameters. You can either pass hyperparameters directly to the constructor or via a config object: diff --git a/docs/overview.md b/docs/overview.md index 7b04f5f..2949dda 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -15,6 +15,8 @@ Tabular data is the most common format in applied machine learning, yet most dee All models support regression, classification, and distributional regression out of the box. Import them as `Regressor`, `Classifier`, or `LSS`. +### Stable + | Model | Architecture | Reference | | ---------------- | ---------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | | `Mambular` | Sequential Mamba (SSM) blocks for tabular data | [Thielmann et al. (2024)](https://arxiv.org/abs/2408.06291) | @@ -32,9 +34,16 @@ All models support regression, classification, and distributional regression out | `TabulaRNN` | Recurrent neural network for tabular data | [Thielmann et al. (2025)](https://arxiv.org/pdf/2411.17207) | | `ENODE` | Extended NODE variant | — | | `AutoInt` | Automatic feature interaction via attention | — | -| `ModernNCA` | Modern neural classification architecture | — | -| `Trompt` | Tabular-specific prompting model | — | -| `TANGOS` | Tabular model with graph-based structure | — | + +### Experimental + +Experimental models are imported from `deeptab.models.experimental`. Their API may change without a deprecation cycle. See [Using experimental models](examples/experimental) for a worked example. + +| Model | Architecture | Reference | +| ----------- | ----------------------------------------- | --------- | +| `ModernNCA` | Modern neural classification architecture | — | +| `Trompt` | Tabular-specific prompting model | — | +| `Tangos` | Tabular model with graph-based structure | — | ## Next steps From 34c918854a63df82561cc848d13a83f606c1bccf Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 16:27:30 +0200 Subject: [PATCH 06/22] docs: add individual model pages, stable/experimental split tables --- docs/api/models/autoint.rst | 43 ++++++++++++++++++++++++++++++ docs/api/models/enode.rst | 40 +++++++++++++++++++++++++++ docs/api/models/fttransformer.rst | 39 +++++++++++++++++++++++++++ docs/api/models/index.rst | 21 +++++++++++++++ docs/api/models/mambatab.rst | 38 ++++++++++++++++++++++++++ docs/api/models/mambattention.rst | 39 +++++++++++++++++++++++++++ docs/api/models/mambular.rst | 40 +++++++++++++++++++++++++++ docs/api/models/mlp.rst | 39 +++++++++++++++++++++++++++ docs/api/models/ndtf.rst | 40 +++++++++++++++++++++++++++ docs/api/models/node.rst | 40 +++++++++++++++++++++++++++ docs/api/models/resnet.rst | 38 ++++++++++++++++++++++++++ docs/api/models/saint.rst | 40 +++++++++++++++++++++++++++ docs/api/models/tabm.rst | 38 ++++++++++++++++++++++++++ docs/api/models/tabr.rst | 40 +++++++++++++++++++++++++++ docs/api/models/tabtransformer.rst | 39 +++++++++++++++++++++++++++ docs/api/models/tabularrnn.rst | 41 ++++++++++++++++++++++++++++ docs/examples/classification.md | 24 +++++++++++++++++ docs/examples/distributional.md | 24 +++++++++++++++++ docs/examples/regression.md | 24 +++++++++++++++++ 19 files changed, 687 insertions(+) create mode 100644 docs/api/models/autoint.rst create mode 100644 docs/api/models/enode.rst create mode 100644 docs/api/models/fttransformer.rst create mode 100644 docs/api/models/mambatab.rst create mode 100644 docs/api/models/mambattention.rst create mode 100644 docs/api/models/mambular.rst create mode 100644 docs/api/models/mlp.rst create mode 100644 docs/api/models/ndtf.rst create mode 100644 docs/api/models/node.rst create mode 100644 docs/api/models/resnet.rst create mode 100644 docs/api/models/saint.rst create mode 100644 docs/api/models/tabm.rst create mode 100644 docs/api/models/tabr.rst create mode 100644 docs/api/models/tabtransformer.rst create mode 100644 docs/api/models/tabularrnn.rst diff --git a/docs/api/models/autoint.rst b/docs/api/models/autoint.rst new file mode 100644 index 0000000..fa8a4aa --- /dev/null +++ b/docs/api/models/autoint.rst @@ -0,0 +1,43 @@ +AutoInt +======= + +Automatic feature Interaction learning via multi-head self-attention on feature +embeddings. Each input feature is projected into an embedding and the +embeddings are passed through stacked multi-head attention layers. Residual +connections allow the model to combine the original feature representation with +the interaction-augmented representation, making the learned interactions +explicitly additive. + +When to Use +----------- + +When capturing explicit pairwise and higher-order feature interactions is the +primary modelling goal. Historically strong in click-through-rate prediction +and recommendation system benchmarks. + +Limitations +----------- + +- Performance is generally comparable to FTTransformer on most generic tabular + benchmarks; FTTransformer is often a simpler first choice. +- Less effective for very high-dimensional sparse feature spaces compared to + factorisation-machine-based methods. +- The additional residual interaction terms add minor overhead vs plain + Transformer models. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: AutoIntRegressor + :members: + :undoc-members: + +.. autoclass:: AutoIntClassifier + :members: + :undoc-members: + +.. autoclass:: AutoIntLSS + :members: + :undoc-members: diff --git a/docs/api/models/enode.rst b/docs/api/models/enode.rst new file mode 100644 index 0000000..a1270ba --- /dev/null +++ b/docs/api/models/enode.rst @@ -0,0 +1,40 @@ +ENODE +===== + +Extended Neural Oblivious Decision Ensembles. ENODE builds on :doc:`node` by +adding explicit feature embedding layers before the decision ensemble. These +embedding layers transform raw input features into richer representations before +they are fed into the differentiable decision trees, improving performance when +the raw feature space is noisy or heterogeneous. + +When to Use +----------- + +Upgrade from NODE when raw feature quality is poor, the data is heterogeneous, +or vanilla NODE underfits. The embedding layers add a small representational +overhead that often pays off on real-world datasets. + +Limitations +----------- + +- Inherits the same fundamental limitations as NODE (high memory, slow training). +- Increased model size compared to plain NODE. +- May be harder to interpret than NODE because the input to the decision + ensemble is no longer the raw feature space. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: ENODERegressor + :members: + :undoc-members: + +.. autoclass:: ENODEClassifier + :members: + :undoc-members: + +.. autoclass:: ENODELSS + :members: + :undoc-members: diff --git a/docs/api/models/fttransformer.rst b/docs/api/models/fttransformer.rst new file mode 100644 index 0000000..f37fc83 --- /dev/null +++ b/docs/api/models/fttransformer.rst @@ -0,0 +1,39 @@ +FTTransformer +============= + +Feature Tokenizer + Transformer. Each input feature — numerical or categorical — +is mapped to a dense token embedding, and the resulting sequence of tokens is +processed through a stack of standard Transformer encoder layers. A ``[CLS]`` +token is prepended and used to produce the final prediction. + +When to Use +----------- + +Strong general-purpose model. Particularly effective on mixed datasets with both +numerical and categorical features where pairwise feature interactions are +important. Typically the first Transformer baseline to try. + +Limitations +----------- + +- Higher memory and compute cost relative to MLP and ResNet. +- Tends to overfit on very small datasets (under ~500 samples); consider adding + dropout or reducing depth. +- Longer training time than simpler architectures. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: FTTransformerRegressor + :members: + :undoc-members: + +.. autoclass:: FTTransformerClassifier + :members: + :undoc-members: + +.. autoclass:: FTTransformerLSS + :members: + :undoc-members: diff --git a/docs/api/models/index.rst b/docs/api/models/index.rst index 84897da..864dc51 100644 --- a/docs/api/models/index.rst +++ b/docs/api/models/index.rst @@ -163,5 +163,26 @@ Modules Description .. toctree:: :maxdepth: 1 + :caption: Stable Models + + mlp + resnet + fttransformer + tabtransformer + saint + tabm + tabr + node + ndtf + tabularrnn + mambular + mambatab + mambattention + enode + autoint + +.. toctree:: + :maxdepth: 1 + :caption: Full API Reference Models diff --git a/docs/api/models/mambatab.rst b/docs/api/models/mambatab.rst new file mode 100644 index 0000000..f2be816 --- /dev/null +++ b/docs/api/models/mambatab.rst @@ -0,0 +1,38 @@ +MambaTab +======== + +A lightweight Mamba-based architecture that applies a single Mamba SSM block to +a joint representation of all input features. Rather than tokenising each +feature individually, MambaTab concatenates all feature embeddings into one +vector, making it the most computationally efficient model in the Mamba family. + +When to Use +----------- + +Efficiency-focused scenarios where a fast Mamba-based baseline is needed before +scaling to the more expressive :doc:`mambular` architecture. Useful when +training or inference speed is a hard constraint. + +Limitations +----------- + +- The joint input representation loses per-feature granularity compared to + token-level models (FTTransformer, Mambular). +- Less expressive than multi-layer Mambular for complex datasets. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MambaTabRegressor + :members: + :undoc-members: + +.. autoclass:: MambaTabClassifier + :members: + :undoc-members: + +.. autoclass:: MambaTabLSS + :members: + :undoc-members: diff --git a/docs/api/models/mambattention.rst b/docs/api/models/mambattention.rst new file mode 100644 index 0000000..d14fc88 --- /dev/null +++ b/docs/api/models/mambattention.rst @@ -0,0 +1,39 @@ +MambAttention +============= + +Hybrid Mamba + Attention architecture. MambAttention interleaves Mamba SSM +layers with multi-head self-attention layers, allowing the model to capture both +local sequential patterns (via Mamba's linear-time recurrence) and global +dependencies across all features simultaneously (via attention). + +When to Use +----------- + +When you need the memory efficiency of Mamba for local patterns and the +expressiveness of attention for global feature interactions. A natural upgrade +from either :doc:`mambular` or :doc:`fttransformer` when neither alone is +sufficient. + +Limitations +----------- + +- More hyperparameters than either Mambular or FTTransformer alone. +- Higher compute and memory cost than a pure Mamba or pure attention model. +- Fewer community benchmarks available; expect more tuning effort. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MambAttentionRegressor + :members: + :undoc-members: + +.. autoclass:: MambAttentionClassifier + :members: + :undoc-members: + +.. autoclass:: MambAttentionLSS + :members: + :undoc-members: diff --git a/docs/api/models/mambular.rst b/docs/api/models/mambular.rst new file mode 100644 index 0000000..d0fa7b3 --- /dev/null +++ b/docs/api/models/mambular.rst @@ -0,0 +1,40 @@ +Mambular +======== + +Sequential Mamba Structured State Space Model (SSM) blocks adapted for tabular +data. Each feature is embedded as a token and the resulting sequence is +processed by stacked Mamba layers, which use efficient linear-time recurrence +rather than quadratic attention. This allows Mambular to scale to longer feature +sequences while keeping memory costs linear. + +When to Use +----------- + +Ordered feature sets or large-scale datasets where Transformer memory costs are +prohibitive. Particularly compelling as an attention-free alternative when the +feature sequence has inherent order (e.g., time-step columns, sensor channels). + +Limitations +----------- + +- Newer architecture with less empirical validation than MLP/ResNet baselines. +- May require more epochs to converge compared to Transformer-based models. +- Performance can be sensitive to the Mamba-specific hyperparameters + (``d_state``, ``expand_factor``). + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MambularRegressor + :members: + :undoc-members: + +.. autoclass:: MambularClassifier + :members: + :undoc-members: + +.. autoclass:: MambularLSS + :members: + :undoc-members: diff --git a/docs/api/models/mlp.rst b/docs/api/models/mlp.rst new file mode 100644 index 0000000..5388b52 --- /dev/null +++ b/docs/api/models/mlp.rst @@ -0,0 +1,39 @@ +MLP +=== + +A fully-connected feedforward network with configurable depth and width. The +simplest and fastest deep learning baseline for tabular data. Each hidden layer +applies a linear transformation followed by an activation function and optional +dropout. + +When to Use +----------- + +Start here before trying more complex architectures. Works well on most datasets +as a fast, low-cost baseline. Ideal for smaller datasets or when compute budget +is limited. Also useful as a sanity-check model to verify the data pipeline. + +Limitations +----------- + +- Cannot model complex feature interactions without explicit feature engineering. +- May underfit on datasets with strong structural or sequential patterns. +- Performance plateaus with depth due to vanishing gradients (use ResNet if this + is a concern). + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: MLPRegressor + :members: + :undoc-members: + +.. autoclass:: MLPClassifier + :members: + :undoc-members: + +.. autoclass:: MLPLSS + :members: + :undoc-members: diff --git a/docs/api/models/ndtf.rst b/docs/api/models/ndtf.rst new file mode 100644 index 0000000..7d45dbc --- /dev/null +++ b/docs/api/models/ndtf.rst @@ -0,0 +1,40 @@ +NDTF +==== + +Neural Decision Tree Forest. An ensemble of differentiable soft decision trees +where routing probabilities at each node are learned via sigmoid activations. +A path-probability regularisation term (controlled by ``lamda``) penalises +over-confident or imbalanced routing, encouraging diverse tree usage across the +forest. + +When to Use +----------- + +When interpretability through decision paths is desirable alongside neural +gradient optimisation. Useful as an alternative to NODE when a forest structure +(multiple independent trees) is preferred over oblivious ensembles. + +Limitations +----------- + +- Sensitive to the ``temperature`` and ``lamda`` regularisation hyperparameters. +- Can underfit with too few trees (``n_ensembles``) or overfit with too many. +- Less effective for very high-dimensional data where feature selection at each + split becomes noisy. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: NDTFRegressor + :members: + :undoc-members: + +.. autoclass:: NDTFClassifier + :members: + :undoc-members: + +.. autoclass:: NDTFLSS + :members: + :undoc-members: diff --git a/docs/api/models/node.rst b/docs/api/models/node.rst new file mode 100644 index 0000000..076c22f --- /dev/null +++ b/docs/api/models/node.rst @@ -0,0 +1,40 @@ +NODE +==== + +Neural Oblivious Decision Ensembles. Each NODE layer is a differentiable +ensemble of oblivious decision trees — trees where the same splitting feature +and threshold is used at every node of a given depth. The trees are made +end-to-end differentiable via entmax transformations, allowing gradient-based +training. + +When to Use +----------- + +When you want the inductive bias of gradient-boosted decision trees inside a +neural framework. Often competitive with gradient boosting on structured tabular +benchmarks while remaining composable as a standard PyTorch layer. + +Limitations +----------- + +- High memory consumption, especially at larger tree depths. +- Slower to train than MLP-based models. +- Sensitive to the ``depth`` hyperparameter; too shallow loses expressiveness, + too deep causes memory and overfitting issues. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: NODERegressor + :members: + :undoc-members: + +.. autoclass:: NODEClassifier + :members: + :undoc-members: + +.. autoclass:: NODELSS + :members: + :undoc-members: diff --git a/docs/api/models/resnet.rst b/docs/api/models/resnet.rst new file mode 100644 index 0000000..37b3ad9 --- /dev/null +++ b/docs/api/models/resnet.rst @@ -0,0 +1,38 @@ +ResNet +====== + +A deep residual network adapted for tabular data. Skip connections let gradients +flow through deeper stacks without vanishing, enabling more representational +capacity than a plain MLP at the same depth. Each residual block applies two +linear layers with batch normalisation and a skip connection. + +When to Use +----------- + +Choose ResNet when a plain MLP fails to converge well or produces unstable +training curves, or when you need more depth without gradient issues. A good +second step after benchmarking MLP. + +Limitations +----------- + +- More hyperparameters than plain MLP (block size, number of blocks). +- Skip connections add memory overhead. +- May not outperform MLP on small datasets where depth is not beneficial. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: ResNetRegressor + :members: + :undoc-members: + +.. autoclass:: ResNetClassifier + :members: + :undoc-members: + +.. autoclass:: ResNetLSS + :members: + :undoc-members: diff --git a/docs/api/models/saint.rst b/docs/api/models/saint.rst new file mode 100644 index 0000000..0863377 --- /dev/null +++ b/docs/api/models/saint.rst @@ -0,0 +1,40 @@ +SAINT +===== + +Self-Attention and Intersample Attention Transformer. SAINT augments the +standard column-wise attention of a Transformer with a second attention +mechanism that operates across rows — allowing each sample to attend to other +samples in the batch. This enables the model to leverage inter-sample +relationships during training. + +When to Use +----------- + +When inter-sample relationships are informative, such as in recommendation or +retrieval tasks. Reported strong performance on semi-supervised tabular +benchmarks. Consider SAINT when FTTransformer leaves significant headroom and +more expressive attention is warranted. + +Limitations +----------- + +- Quadratic memory complexity in batch size due to intersample attention. +- Significantly slower than single-sample Transformer models on large batches. +- Gains over simpler models are dataset-dependent; not always worth the extra cost. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: SAINTRegressor + :members: + :undoc-members: + +.. autoclass:: SAINTClassifier + :members: + :undoc-members: + +.. autoclass:: SAINTLSS + :members: + :undoc-members: diff --git a/docs/api/models/tabm.rst b/docs/api/models/tabm.rst new file mode 100644 index 0000000..db506d3 --- /dev/null +++ b/docs/api/models/tabm.rst @@ -0,0 +1,38 @@ +TabM +==== + +Batch ensembling applied to an MLP. TabM trains multiple ensemble members that +share most of their weights, with only lightweight per-member scaling factors +making each head distinct. This delivers ensemble-level accuracy at near +single-model memory and compute cost. + +When to Use +----------- + +When you want ensembling diversity without the cost of training multiple +independent models. A strong regularised baseline that often outperforms plain +MLP with minimal extra overhead. + +Limitations +----------- + +- Slightly higher memory footprint than a plain MLP due to the per-member factors. +- The number of ensemble members is an additional hyperparameter to tune. +- Gains diminish beyond a moderate number of members. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabMRegressor + :members: + :undoc-members: + +.. autoclass:: TabMClassifier + :members: + :undoc-members: + +.. autoclass:: TabMLSS + :members: + :undoc-members: diff --git a/docs/api/models/tabr.rst b/docs/api/models/tabr.rst new file mode 100644 index 0000000..30c4eff --- /dev/null +++ b/docs/api/models/tabr.rst @@ -0,0 +1,40 @@ +TabR +==== + +Retrieval-augmented tabular model. At inference time, TabR retrieves the most +similar training examples from a stored memory of embeddings and uses them as +additional context when computing the prediction. This gives the model access to +local neighbourhood information beyond what is encoded in its weights. + +When to Use +----------- + +Datasets where local similarity structure is informative — rows that are similar +in feature space tend to share similar targets. Effective on low-to-medium-size +datasets where a full nearest-neighbour memory can be maintained affordably. + +Limitations +----------- + +- Inference time scales with training set size as the model must search the + memory store. +- Not suitable for very large datasets (>100 k rows) without approximate + nearest-neighbour indexing. +- Requires keeping the training set in memory during inference. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabRRegressor + :members: + :undoc-members: + +.. autoclass:: TabRClassifier + :members: + :undoc-members: + +.. autoclass:: TabRLSS + :members: + :undoc-members: diff --git a/docs/api/models/tabtransformer.rst b/docs/api/models/tabtransformer.rst new file mode 100644 index 0000000..3d4134a --- /dev/null +++ b/docs/api/models/tabtransformer.rst @@ -0,0 +1,39 @@ +TabTransformer +============== + +Transformer for tabular data with a focus on categorical feature embeddings. +Categorical features are embedded and passed through Transformer encoder layers +to capture inter-categorical dependencies, while numerical features bypass the +attention mechanism and are concatenated at the prediction head. + +When to Use +----------- + +Datasets dominated by high-cardinality categorical features where relationships +between categories are informative. Commonly used in click-through-rate +prediction and entity-heavy tabular problems. + +Limitations +----------- + +- Limited benefit for datasets with mostly numerical features. +- Slower than MLP-based models. +- FTTransformer typically outperforms TabTransformer on mixed datasets because + it tokenises all features uniformly. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabTransformerRegressor + :members: + :undoc-members: + +.. autoclass:: TabTransformerClassifier + :members: + :undoc-members: + +.. autoclass:: TabTransformerLSS + :members: + :undoc-members: diff --git a/docs/api/models/tabularrnn.rst b/docs/api/models/tabularrnn.rst new file mode 100644 index 0000000..3e77286 --- /dev/null +++ b/docs/api/models/tabularrnn.rst @@ -0,0 +1,41 @@ +TabulaRNN +========= + +Recurrent neural network for tabular data. TabulaRNN treats the feature vector +as a sequence of tokens and processes it with a recurrent cell. The cell type is +configurable: ``RNN``, ``LSTM``, ``GRU``, ``mLSTM`` (matrix LSTM), or +``sLSTM`` (scalar LSTM from the xLSTM family). This makes it a flexible +sequence model that spans classical to modern recurrent architectures. + +When to Use +----------- + +Best suited for datasets where feature ordering encodes meaningful structure — +for example, temporally ordered measurements stored as columns. Also a viable +alternative to Transformer-based models when memory efficiency is a priority. + +Limitations +----------- + +- Performance is sensitive to feature ordering; shuffling columns can + significantly change results. +- May underperform Transformer architectures on unordered tabular data where + positional bias is irrelevant. +- The mLSTM and sLSTM variants are newer and less empirically validated. + +API Reference +------------- + +.. currentmodule:: deeptab.models + +.. autoclass:: TabulaRNNRegressor + :members: + :undoc-members: + +.. autoclass:: TabulaRNNClassifier + :members: + :undoc-members: + +.. autoclass:: TabulaRNNLSS + :members: + :undoc-members: diff --git a/docs/examples/classification.md b/docs/examples/classification.md index 7c2586b..8812cd4 100644 --- a/docs/examples/classification.md +++ b/docs/examples/classification.md @@ -79,6 +79,30 @@ model.fit(X_train, y_train, max_epochs=50) print(model.evaluate(X_test, y_test)) ``` +## All stable classifiers + +Swap `MambularClassifier` for any class below — no other code changes are needed: + +| Class | Architecture | Notes | +|---|---|---| +| `MLPClassifier` | Feedforward MLP | Fastest baseline | +| `ResNetClassifier` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerClassifier` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerClassifier` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTClassifier` | Self + intersample attention | Good for semi-supervised settings | +| `TabMClassifier` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRClassifier` | Retrieval-augmented | Strong when local similarity matters | +| `NODEClassifier` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFClassifier` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNClassifier` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularClassifier` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabClassifier` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionClassifier` | Mamba + attention hybrid | Local + global patterns | +| `ENODEClassifier` | Extended NODE | NODE with feature embeddings | +| `AutoIntClassifier` | Attention-based interaction | Explicit feature crossing | + +Experimental classifiers (`ModernNCAClassifier`, `TromptClassifier`, `TangosClassifier`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). + ## Next steps - [Key Concepts](../key_concepts) — learn how to tune hyperparameters via config objects. diff --git a/docs/examples/distributional.md b/docs/examples/distributional.md index 31e537a..3bba092 100644 --- a/docs/examples/distributional.md +++ b/docs/examples/distributional.md @@ -76,6 +76,30 @@ model.fit(X_train, y_train, family="normal", max_epochs=50) print(model.evaluate(X_test, y_test)) ``` +## All stable LSS models + +Swap `MambularLSS` for any class below — pass `family=` to `.fit()` to select the output distribution: + +| Class | Architecture | Notes | +|---|---|---| +| `MLPLSS` | Feedforward MLP | Fastest baseline | +| `ResNetLSS` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerLSS` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerLSS` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTLSS` | Self + intersample attention | Good for semi-supervised settings | +| `TabMLSS` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRLSS` | Retrieval-augmented | Strong when local similarity matters | +| `NODELSS` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFLSS` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNLSS` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularLSS` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabLSS` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionLSS` | Mamba + attention hybrid | Local + global patterns | +| `ENODELSS` | Extended NODE | NODE with feature embeddings | +| `AutoIntLSS` | Attention-based interaction | Explicit feature crossing | + +Experimental LSS models (`ModernNCALSS`, `TromptLSS`, `TangosLSS`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). + ## Next steps - [Key Concepts](../key_concepts) — understand the `LSS` task variant and available distribution families. diff --git a/docs/examples/regression.md b/docs/examples/regression.md index b21dea8..4660fa9 100644 --- a/docs/examples/regression.md +++ b/docs/examples/regression.md @@ -77,6 +77,30 @@ model.fit(X_train, y_train, max_epochs=50) print(model.evaluate(X_test, y_test)) ``` +## All stable regressors + +Swap `MambularRegressor` for any class below — no other code changes are needed: + +| Class | Architecture | Notes | +|---|---|---| +| `MLPRegressor` | Feedforward MLP | Fastest baseline | +| `ResNetRegressor` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerRegressor` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerRegressor` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTRegressor` | Self + intersample attention | Good for semi-supervised settings | +| `TabMRegressor` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRRegressor` | Retrieval-augmented | Strong when local similarity matters | +| `NODERegressor` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFRegressor` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNRegressor` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularRegressor` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabRegressor` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionRegressor` | Mamba + attention hybrid | Local + global patterns | +| `ENODERegressor` | Extended NODE | NODE with feature embeddings | +| `AutoIntRegressor` | Attention-based interaction | Explicit feature crossing | + +Experimental regressors (`ModernNCARegressor`, `TromptRegressor`, `TangosRegressor`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). + ## Next steps - [Key Concepts](../key_concepts) — learn how to tune hyperparameters via config objects. From 0ceb06fd67ce5ed241a993854d247bf4f0e3c2ef Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 16:27:46 +0200 Subject: [PATCH 07/22] fix: DefaultTabRConfig export --- deeptab/configs/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deeptab/configs/__init__.py b/deeptab/configs/__init__.py index 63cf376..287e358 100644 --- a/deeptab/configs/__init__.py +++ b/deeptab/configs/__init__.py @@ -12,6 +12,7 @@ from .resnet_config import DefaultResNetConfig from .saint_config import DefaultSAINTConfig from .tabm_config import DefaultTabMConfig +from .tabr_config import DefaultTabRConfig from .tabtransformer_config import DefaultTabTransformerConfig from .tabularnn_config import DefaultTabulaRNNConfig from .tangos_config import DefaultTangosConfig @@ -32,6 +33,7 @@ "DefaultResNetConfig", "DefaultSAINTConfig", "DefaultTabMConfig", + "DefaultTabRConfig", "DefaultTabTransformerConfig", "DefaultTabulaRNNConfig", "DefaultTangosConfig", From 3a616ce04e80fe096fd3b4b93a18cc618233a20a Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 16:28:40 +0200 Subject: [PATCH 08/22] docs: tables created for examples --- docs/examples/classification.md | 34 ++++++++++++++++----------------- docs/examples/distributional.md | 34 ++++++++++++++++----------------- docs/examples/regression.md | 34 ++++++++++++++++----------------- 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/docs/examples/classification.md b/docs/examples/classification.md index 8812cd4..054dea1 100644 --- a/docs/examples/classification.md +++ b/docs/examples/classification.md @@ -83,23 +83,23 @@ print(model.evaluate(X_test, y_test)) Swap `MambularClassifier` for any class below — no other code changes are needed: -| Class | Architecture | Notes | -|---|---|---| -| `MLPClassifier` | Feedforward MLP | Fastest baseline | -| `ResNetClassifier` | Residual MLP | Better than MLP for deeper networks | -| `FTTransformerClassifier` | Feature-Tokenizer Transformer | Strong general-purpose model | -| `TabTransformerClassifier` | Transformer on categorical embeddings | Best for categorical-heavy data | -| `SAINTClassifier` | Self + intersample attention | Good for semi-supervised settings | -| `TabMClassifier` | Batch-ensembling MLP | Ensemble accuracy at low cost | -| `TabRClassifier` | Retrieval-augmented | Strong when local similarity matters | -| `NODEClassifier` | Differentiable decision trees | Gradient-boosting inductive bias | -| `NDTFClassifier` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | -| `TabulaRNNClassifier` | RNN / LSTM / GRU | Use `model_type` to select cell | -| `MambularClassifier` | Stacked Mamba SSM | Efficient sequence model | -| `MambaTabClassifier` | Single Mamba block | Lightest Mamba variant | -| `MambAttentionClassifier` | Mamba + attention hybrid | Local + global patterns | -| `ENODEClassifier` | Extended NODE | NODE with feature embeddings | -| `AutoIntClassifier` | Attention-based interaction | Explicit feature crossing | +| Class | Architecture | Notes | +| -------------------------- | ------------------------------------- | ------------------------------------ | +| `MLPClassifier` | Feedforward MLP | Fastest baseline | +| `ResNetClassifier` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerClassifier` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerClassifier` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTClassifier` | Self + intersample attention | Good for semi-supervised settings | +| `TabMClassifier` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRClassifier` | Retrieval-augmented | Strong when local similarity matters | +| `NODEClassifier` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFClassifier` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNClassifier` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularClassifier` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabClassifier` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionClassifier` | Mamba + attention hybrid | Local + global patterns | +| `ENODEClassifier` | Extended NODE | NODE with feature embeddings | +| `AutoIntClassifier` | Attention-based interaction | Explicit feature crossing | Experimental classifiers (`ModernNCAClassifier`, `TromptClassifier`, `TangosClassifier`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). diff --git a/docs/examples/distributional.md b/docs/examples/distributional.md index 3bba092..75af6e0 100644 --- a/docs/examples/distributional.md +++ b/docs/examples/distributional.md @@ -80,23 +80,23 @@ print(model.evaluate(X_test, y_test)) Swap `MambularLSS` for any class below — pass `family=` to `.fit()` to select the output distribution: -| Class | Architecture | Notes | -|---|---|---| -| `MLPLSS` | Feedforward MLP | Fastest baseline | -| `ResNetLSS` | Residual MLP | Better than MLP for deeper networks | -| `FTTransformerLSS` | Feature-Tokenizer Transformer | Strong general-purpose model | -| `TabTransformerLSS` | Transformer on categorical embeddings | Best for categorical-heavy data | -| `SAINTLSS` | Self + intersample attention | Good for semi-supervised settings | -| `TabMLSS` | Batch-ensembling MLP | Ensemble accuracy at low cost | -| `TabRLSS` | Retrieval-augmented | Strong when local similarity matters | -| `NODELSS` | Differentiable decision trees | Gradient-boosting inductive bias | -| `NDTFLSS` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | -| `TabulaRNNLSS` | RNN / LSTM / GRU | Use `model_type` to select cell | -| `MambularLSS` | Stacked Mamba SSM | Efficient sequence model | -| `MambaTabLSS` | Single Mamba block | Lightest Mamba variant | -| `MambAttentionLSS` | Mamba + attention hybrid | Local + global patterns | -| `ENODELSS` | Extended NODE | NODE with feature embeddings | -| `AutoIntLSS` | Attention-based interaction | Explicit feature crossing | +| Class | Architecture | Notes | +| ------------------- | ------------------------------------- | ------------------------------------ | +| `MLPLSS` | Feedforward MLP | Fastest baseline | +| `ResNetLSS` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerLSS` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerLSS` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTLSS` | Self + intersample attention | Good for semi-supervised settings | +| `TabMLSS` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRLSS` | Retrieval-augmented | Strong when local similarity matters | +| `NODELSS` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFLSS` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNLSS` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularLSS` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabLSS` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionLSS` | Mamba + attention hybrid | Local + global patterns | +| `ENODELSS` | Extended NODE | NODE with feature embeddings | +| `AutoIntLSS` | Attention-based interaction | Explicit feature crossing | Experimental LSS models (`ModernNCALSS`, `TromptLSS`, `TangosLSS`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). diff --git a/docs/examples/regression.md b/docs/examples/regression.md index 4660fa9..48d847e 100644 --- a/docs/examples/regression.md +++ b/docs/examples/regression.md @@ -81,23 +81,23 @@ print(model.evaluate(X_test, y_test)) Swap `MambularRegressor` for any class below — no other code changes are needed: -| Class | Architecture | Notes | -|---|---|---| -| `MLPRegressor` | Feedforward MLP | Fastest baseline | -| `ResNetRegressor` | Residual MLP | Better than MLP for deeper networks | -| `FTTransformerRegressor` | Feature-Tokenizer Transformer | Strong general-purpose model | -| `TabTransformerRegressor` | Transformer on categorical embeddings | Best for categorical-heavy data | -| `SAINTRegressor` | Self + intersample attention | Good for semi-supervised settings | -| `TabMRegressor` | Batch-ensembling MLP | Ensemble accuracy at low cost | -| `TabRRegressor` | Retrieval-augmented | Strong when local similarity matters | -| `NODERegressor` | Differentiable decision trees | Gradient-boosting inductive bias | -| `NDTFRegressor` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | -| `TabulaRNNRegressor` | RNN / LSTM / GRU | Use `model_type` to select cell | -| `MambularRegressor` | Stacked Mamba SSM | Efficient sequence model | -| `MambaTabRegressor` | Single Mamba block | Lightest Mamba variant | -| `MambAttentionRegressor` | Mamba + attention hybrid | Local + global patterns | -| `ENODERegressor` | Extended NODE | NODE with feature embeddings | -| `AutoIntRegressor` | Attention-based interaction | Explicit feature crossing | +| Class | Architecture | Notes | +| ------------------------- | ------------------------------------- | ------------------------------------ | +| `MLPRegressor` | Feedforward MLP | Fastest baseline | +| `ResNetRegressor` | Residual MLP | Better than MLP for deeper networks | +| `FTTransformerRegressor` | Feature-Tokenizer Transformer | Strong general-purpose model | +| `TabTransformerRegressor` | Transformer on categorical embeddings | Best for categorical-heavy data | +| `SAINTRegressor` | Self + intersample attention | Good for semi-supervised settings | +| `TabMRegressor` | Batch-ensembling MLP | Ensemble accuracy at low cost | +| `TabRRegressor` | Retrieval-augmented | Strong when local similarity matters | +| `NODERegressor` | Differentiable decision trees | Gradient-boosting inductive bias | +| `NDTFRegressor` | Neural decision tree forest | Use `n_ensembles` and `max_depth` | +| `TabulaRNNRegressor` | RNN / LSTM / GRU | Use `model_type` to select cell | +| `MambularRegressor` | Stacked Mamba SSM | Efficient sequence model | +| `MambaTabRegressor` | Single Mamba block | Lightest Mamba variant | +| `MambAttentionRegressor` | Mamba + attention hybrid | Local + global patterns | +| `ENODERegressor` | Extended NODE | NODE with feature embeddings | +| `AutoIntRegressor` | Attention-based interaction | Explicit feature crossing | Experimental regressors (`ModernNCARegressor`, `TromptRegressor`, `TangosRegressor`) are available from `deeptab.models.experimental`. See [Experimental models](experimental). From 932ea31b61aa8e6c045219184e622cbc107c5747 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 17:05:41 +0200 Subject: [PATCH 09/22] fix: add delu dependency for TabR --- poetry.lock | 18 +++++++++++++++++- pyproject.toml | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 785cec0..efa5ede 100644 --- a/poetry.lock +++ b/poetry.lock @@ -780,6 +780,22 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "delu" +version = "0.0.26" +description = "Deep Learning Utilities for PyTorch users." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "delu-0.0.26-py3-none-any.whl", hash = "sha256:b074f0ad5664b80d70096ead0bd6c69133a9ff883a0f7b341b7570619c91870b"}, + {file = "delu-0.0.26.tar.gz", hash = "sha256:460fb6e1bb8b8fcc5639337486c3d32752820147be7442f1d52bbf93beb919a4"}, +] + +[package.dependencies] +numpy = ">=1.21,<3" +torch = ">=1.9,<3" + [[package]] name = "distlib" version = "0.3.9" @@ -4888,4 +4904,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "83960ab51d000c28565beef5350b987f673689b24e033856dc43d50bdf4284ab" +content-hash = "5bcf8e13789e6e6d6b90091278c75c9d553920aacc6d55db64c7a8dc455603ff" diff --git a/pyproject.toml b/pyproject.toml index 67c8a2d..140d2b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ einops = "^0.8.0" accelerate = "^1.2.1" scipy = "^1.15.0" pretab = "^0.0.1" +delu = "*" [tool.poetry.group.dev.dependencies] pytest = "^8.1" From fc891b4e12d9a3b8815e8bda261398214ef05cb3 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 17:34:22 +0200 Subject: [PATCH 10/22] fix: added faiss-cpu --- poetry.lock | 31 ++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index efa5ede..99b04ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -894,6 +894,35 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] +[[package]] +name = "faiss-cpu" +version = "1.13.2" +description = "A library for efficient similarity search and clustering of dense vectors." +optional = false +python-versions = "<3.15,>=3.10" +groups = ["main"] +files = [ + {file = "faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_arm64.whl", hash = "sha256:a9064eb34f8f64438dd5b95c8f03a780b1a3f0b99c46eeacb1f0b5d15fc02dc1"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_x86_64.whl", hash = "sha256:c8d097884521e1ecaea6467aeebbf1aa56ee4a36350b48b2ca6b39366565c317"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ee330a284042c2480f2e90450a10378fd95655d62220159b1408f59ee83ebf1"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ab88ee287c25a119213153d033f7dd64c3ccec466ace267395872f554b648cd7"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:85511129b34f890d19c98b82a0cd5ffb27d89d1cec2ee41d2621ee9f9ef8cf3f"}, + {file = "faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b32eb4065bac352b52a9f5ae07223567fab0a976c7d05017c01c45a1c24264f"}, + {file = "faiss_cpu-1.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eb8bf5dd96465d043c22195afbe8276d5197b710704290d9b454144a0ad892ed"}, + {file = "faiss_cpu-1.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:7c5944d7807d58fe7244b6aba06be710ee7ed99343365ed92699349efe979f51"}, + {file = "faiss_cpu-1.13.2-cp311-cp311-win_arm64.whl", hash = "sha256:19508a1badfb36e456c1c8664eeb948349f604db5c7545f277a0126b4a84b080"}, + {file = "faiss_cpu-1.13.2-cp312-cp312-win_amd64.whl", hash = "sha256:b82c01d30430dd7b1fa442001b9099735d1a82f6bb72033acdc9206d5ac66a64"}, + {file = "faiss_cpu-1.13.2-cp312-cp312-win_arm64.whl", hash = "sha256:2c4f696ae76e7c97cbc12311db83aaf1e7f4f7be06a3ffea7e5b0e8ec1fd805b"}, + {file = "faiss_cpu-1.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:cb4b5ee184816a4b099162ac93c0d7f0033d81a88e7c1291d0a9cc41ec348984"}, + {file = "faiss_cpu-1.13.2-cp313-cp313-win_arm64.whl", hash = "sha256:1243967eeb2298791ff7f3683a4abd2100d7e6ec7542ca05c3b75d47a7f621e5"}, + {file = "faiss_cpu-1.13.2-cp314-cp314-win_amd64.whl", hash = "sha256:c8b645e7d56591aa35dc75415bb53a62e4a494dba010e16f4b67daeffd830bd7"}, + {file = "faiss_cpu-1.13.2-cp314-cp314-win_arm64.whl", hash = "sha256:8113a2a80b59fe5653cf66f5c0f18be0a691825601a52a614c30beb1fca9bc7c"}, +] + +[package.dependencies] +numpy = ">=1.25.0,<3.0" +packaging = "*" + [[package]] name = "fastjsonschema" version = "2.21.2" @@ -4904,4 +4933,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "5bcf8e13789e6e6d6b90091278c75c9d553920aacc6d55db64c7a8dc455603ff" +content-hash = "cd3e8982de0d6cb28767f7bfe3aef250c1254eff982f84a740633aa319bb09b0" diff --git a/pyproject.toml b/pyproject.toml index 140d2b4..a98222e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ accelerate = "^1.2.1" scipy = "^1.15.0" pretab = "^0.0.1" delu = "*" +faiss-cpu = "*" [tool.poetry.group.dev.dependencies] pytest = "^8.1" From 1867178d4a8243ca693a559ae898f2b6d6b3d43f Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 23:12:33 +0200 Subject: [PATCH 11/22] fix(ndtf): correct ensemble aggregation for multi-class and LSS outputs --- deeptab/base_models/ndtf.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/deeptab/base_models/ndtf.py b/deeptab/base_models/ndtf.py index 483b19e..e061482 100644 --- a/deeptab/base_models/ndtf.py +++ b/deeptab/base_models/ndtf.py @@ -123,9 +123,9 @@ def forward(self, *data) -> torch.Tensor: tree_input = x[:, : self.input_dimensions[idx]] preds.append(tree(tree_input, return_penalty=False)) - preds = torch.stack(preds, dim=1).squeeze(-1) - - return preds @ self.tree_weights + preds = torch.stack(preds, dim=1) # (batch, n_ensembles, output_dim) + # Weighted sum over ensemble dim: (batch, output_dim, n_ensembles) @ (n_ensembles, 1) + return (preds.transpose(1, 2) @ self.tree_weights).squeeze(-1) def penalty_forward(self, *data) -> torch.Tensor: """Forward pass of the NDTF model. @@ -158,5 +158,6 @@ def penalty_forward(self, *data) -> torch.Tensor: penalty += pen # Stack predictions and calculate mean across trees - preds = torch.stack(preds, dim=1).squeeze(-1) - return preds @ self.tree_weights, self.hparams.penalty_factor * penalty # type: ignore + preds = torch.stack(preds, dim=1) # (batch, n_ensembles, output_dim) + # Weighted sum over ensemble dim: (batch, output_dim, n_ensembles) @ (n_ensembles, 1) + return (preds.transpose(1, 2) @ self.tree_weights).squeeze(-1), self.hparams.penalty_factor * penalty # type: ignore From 45b4b26c7fcb377acbb246601e27a9b6336317d5 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 23:13:03 +0200 Subject: [PATCH 12/22] fix(tabtrans): add dedicated LayerNorm for numerical features instead of reusing encoder norm --- deeptab/base_models/tabtransformer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deeptab/base_models/tabtransformer.py b/deeptab/base_models/tabtransformer.py index 53e7ac2..9446904 100644 --- a/deeptab/base_models/tabtransformer.py +++ b/deeptab/base_models/tabtransformer.py @@ -96,8 +96,11 @@ def __init__( mlp_input_dim = 0 for feature_name, info in num_feature_info.items(): mlp_input_dim += info["dimension"] + num_input_dim = mlp_input_dim # save before adding d_model mlp_input_dim += self.hparams.d_model + self.num_norm = nn.LayerNorm(num_input_dim) + self.tabular_head = MLPhead( input_dim=mlp_input_dim, config=config, @@ -125,13 +128,13 @@ def forward(self, *data): cat_embeddings = self.embedding_layer(*(None, cat_features, emb_features)) num_features = torch.cat(num_features, dim=1) - num_embeddings = self.norm_f(num_features) # type: ignore + num_features = self.num_norm(num_features) x = self.encoder(cat_embeddings) x = self.pool_sequence(x) - x = torch.cat((x, num_embeddings), axis=1) # type: ignore + x = torch.cat((x, num_features), axis=1) # type: ignore preds = self.tabular_head(x) return preds From 97bed3e9a120535069bd917fc6a5bdddb288f0b7 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 23:13:58 +0200 Subject: [PATCH 13/22] feat(models): add save/load, fix get_params and LSS family pickling --- deeptab/models/utils/sklearn_base_lss.py | 149 ++++++++++++++++++++++- deeptab/models/utils/sklearn_parent.py | 140 ++++++++++++++++++++- 2 files changed, 285 insertions(+), 4 deletions(-) diff --git a/deeptab/models/utils/sklearn_base_lss.py b/deeptab/models/utils/sklearn_base_lss.py index de6e8e9..7baa6aa 100644 --- a/deeptab/models/utils/sklearn_base_lss.py +++ b/deeptab/models/utils/sklearn_base_lss.py @@ -38,6 +38,20 @@ StudentTDistribution, ) +DISTRIBUTION_CLASSES = { + "normal": NormalDistribution, + "poisson": PoissonDistribution, + "gamma": GammaDistribution, + "beta": BetaDistribution, + "dirichlet": DirichletDistribution, + "studentt": StudentTDistribution, + "negativebinom": NegativeBinomialDistribution, + "inversegamma": InverseGammaDistribution, + "categorical": CategoricalDistribution, + "quantile": Quantile, + "johnsonsu": JohnsonSuDistribution, +} + class SklearnBaseLSS(BaseEstimator): def __init__(self, model, config, **kwargs): @@ -105,8 +119,8 @@ def get_params(self, deep=True): params = {} params.update(self.config_kwargs) - if deep: - preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} # type: ignore[attr-defined] + if deep and hasattr(self.preprocessor, "get_params"): + preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} params.update(preprocessor_params) return params @@ -381,6 +395,7 @@ def fit( if family in distribution_classes: self.family = distribution_classes[family](**distributional_kwargs) + self.family_name = family else: raise ValueError(f"Unsupported family: {family}") @@ -622,6 +637,136 @@ def encode(self, X, batch_size=64): return encoded_outputs + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Save the fitted model to *path*. + + Parameters + ---------- + path : str + Destination file path (e.g. ``"model.pt"``). + + Raises + ------ + ValueError + If the model has not been fitted yet. + """ + if not getattr(self, "is_fitted_", False): + raise ValueError("Model must be fitted before saving.") + bundle = { + "_class": type(self), + "config": self.config, + "config_kwargs": self.config_kwargs, + "preprocessor": self.preprocessor, + "feature_info": { + "num": self.data_module.num_feature_info, + "cat": self.data_module.cat_feature_info, + "emb": self.data_module.embedding_feature_info, + }, + "batch_size": self.data_module.batch_size, + "regression": self.data_module.regression, + "model_class": type(self.estimator), + "num_classes": self.task_model.num_classes, + "lss": True, + "family": self.family_name, + "optimizer_type": self.optimizer_type, + "optimizer_kwargs": self.optimizer_kwargs, + "lr": self.task_model.lr, + "lr_patience": self.task_model.lr_patience, + "lr_factor": self.task_model.lr_factor, + "weight_decay": self.task_model.weight_decay, + "task_model_state_dict": self.task_model.state_dict(), + } + torch.save(bundle, path) + + @classmethod + def load(cls, path: str): + """Load and return a fitted model from *path*. + + Parameters + ---------- + path : str + Path to a file previously written by :meth:`save`. + + Returns + ------- + estimator + A fully reconstructed, ready-to-predict estimator. + """ + bundle = torch.load(path, weights_only=False) + + obj = bundle["_class"].__new__(bundle["_class"]) + obj.config = bundle["config"] + obj.config_kwargs = bundle["config_kwargs"] + obj.preprocessor = bundle["preprocessor"] + obj.optimizer_type = bundle["optimizer_type"] + obj.optimizer_kwargs = bundle["optimizer_kwargs"] + obj.built = True + obj.is_fitted_ = True + obj.family = DISTRIBUTION_CLASSES[bundle["family"]]() + obj.family_name = bundle["family"] + obj.preprocessor_arg_names = [ + "n_bins", + "feature_preprocessing", + "numerical_preprocessing", + "categorical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "degree", + "scaling_strategy", + "n_knots", + "use_decision_tree_knots", + "knots_strategy", + "spline_implementation", + ] + + obj.data_module = MambularDataModule( + preprocessor=bundle["preprocessor"], + batch_size=bundle["batch_size"], + shuffle=False, + regression=bundle["regression"], + ) + obj.data_module.num_feature_info = bundle["feature_info"]["num"] + obj.data_module.cat_feature_info = bundle["feature_info"]["cat"] + obj.data_module.embedding_feature_info = bundle["feature_info"]["emb"] + + obj.task_model = TaskModel( + model_class=bundle["model_class"], + config=bundle["config"], + feature_information=( + bundle["feature_info"]["num"], + bundle["feature_info"]["cat"], + bundle["feature_info"]["emb"], + ), + num_classes=bundle["num_classes"], + lss=bundle["lss"], + family=obj.family, + optimizer_type=bundle["optimizer_type"], + optimizer_args=bundle["optimizer_kwargs"], + lr=bundle["lr"], + lr_patience=bundle["lr_patience"], + lr_factor=bundle["lr_factor"], + weight_decay=bundle["weight_decay"], + ) + obj.task_model.load_state_dict(bundle["task_model_state_dict"]) + obj.task_model.eval() + obj.estimator = obj.task_model.estimator + + obj.trainer = pl.Trainer( + max_epochs=1, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + + return obj + def optimize_hparams( self, X, diff --git a/deeptab/models/utils/sklearn_parent.py b/deeptab/models/utils/sklearn_parent.py index 709c723..1c15b8e 100644 --- a/deeptab/models/utils/sklearn_parent.py +++ b/deeptab/models/utils/sklearn_parent.py @@ -63,10 +63,10 @@ def get_params(self, deep=True): params = {} params.update(self.config_kwargs) params.update(self.preprocessor_kwargs) - if deep: + if deep and hasattr(self.preprocessor, "get_params"): preprocessor_params = { key: value - for key, value in self.preprocessor.get_params().items() # type: ignore[attr-defined] + for key, value in self.preprocessor.get_params().items() if key in self.preprocessor_arg_names } params.update(preprocessor_params) @@ -478,6 +478,142 @@ def _pretrain( pool_sequence=pool_sequence, ) + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Save the fitted model to *path*. + + The bundle written by this method can be restored with + :meth:`load`. It contains all state required for inference: + the config, the fitted preprocessor, feature metadata, and + the neural-network weights. + + Parameters + ---------- + path : str + Destination file path (e.g. ``"model.pt"``). + + Raises + ------ + ValueError + If the model has not been fitted yet. + """ + if not getattr(self, "is_fitted_", False): + raise ValueError("Model must be fitted before saving.") + bundle = { + "_class": type(self), + "config": self.config, + "config_kwargs": self.config_kwargs, + "preprocessor_kwargs": getattr(self, "preprocessor_kwargs", {}), + "preprocessor": self.preprocessor, + "feature_info": { + "num": self.data_module.num_feature_info, + "cat": self.data_module.cat_feature_info, + "emb": self.data_module.embedding_feature_info, + }, + "batch_size": self.data_module.batch_size, + "regression": self.data_module.regression, + "model_class": type(self.estimator), + "num_classes": self.task_model.num_classes, + "lss": False, + "family": None, + "optimizer_type": self.optimizer_type, + "optimizer_kwargs": self.optimizer_kwargs, + "lr": self.task_model.lr, + "lr_patience": self.task_model.lr_patience, + "lr_factor": self.task_model.lr_factor, + "weight_decay": self.task_model.weight_decay, + "task_model_state_dict": self.task_model.state_dict(), + } + torch.save(bundle, path) + + @classmethod + def load(cls, path: str): + """Load and return a fitted model from *path*. + + Parameters + ---------- + path : str + Path to a file previously written by :meth:`save`. + + Returns + ------- + estimator + A fully reconstructed, ready-to-predict estimator of the + same type that was saved. + """ + bundle = torch.load(path, weights_only=False) + + obj = bundle["_class"].__new__(bundle["_class"]) + obj.config = bundle["config"] + obj.config_kwargs = bundle["config_kwargs"] + obj.preprocessor_kwargs = bundle.get("preprocessor_kwargs", {}) + obj.preprocessor = bundle["preprocessor"] + obj.optimizer_type = bundle["optimizer_type"] + obj.optimizer_kwargs = bundle["optimizer_kwargs"] + obj.built = True + obj.is_fitted_ = True + obj.preprocessor_arg_names = [ + "n_bins", + "feature_preprocessing", + "numerical_preprocessing", + "categorical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "degree", + "scaling_strategy", + "n_knots", + "use_decision_tree_knots", + "knots_strategy", + "spline_implementation", + ] + + obj.data_module = MambularDataModule( + preprocessor=bundle["preprocessor"], + batch_size=bundle["batch_size"], + shuffle=False, + regression=bundle["regression"], + ) + obj.data_module.num_feature_info = bundle["feature_info"]["num"] + obj.data_module.cat_feature_info = bundle["feature_info"]["cat"] + obj.data_module.embedding_feature_info = bundle["feature_info"]["emb"] + + obj.task_model = TaskModel( + model_class=bundle["model_class"], + config=bundle["config"], + feature_information=( + bundle["feature_info"]["num"], + bundle["feature_info"]["cat"], + bundle["feature_info"]["emb"], + ), + num_classes=bundle["num_classes"], + lss=bundle["lss"], + family=bundle["family"], + optimizer_type=bundle["optimizer_type"], + optimizer_args=bundle["optimizer_kwargs"], + lr=bundle["lr"], + lr_patience=bundle["lr_patience"], + lr_factor=bundle["lr_factor"], + weight_decay=bundle["weight_decay"], + ) + obj.task_model.load_state_dict(bundle["task_model_state_dict"]) + obj.task_model.eval() + obj.estimator = obj.task_model.estimator + + obj.trainer = pl.Trainer( + max_epochs=1, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + + return obj + def optimize_hparams( self, X, From a01a2e2b112325d1b0765cd133c37deb74620b11 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Thu, 14 May 2026 23:15:47 +0200 Subject: [PATCH 14/22] test: full 15-model coverage with save/load, config round-trip, and TabR platform skip --- tests/test_models.py | 154 +++++++++++++++++++++++++++++++++++++++- tests/test_save_load.py | 139 ++++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 3 deletions(-) create mode 100644 tests/test_save_load.py diff --git a/tests/test_models.py b/tests/test_models.py index f0fadf7..fdc4474 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,29 +1,70 @@ """ End-to-end behavioural tests for the sklearn-compatible model API. -Tests cover fit → predict → evaluate for a representative subset of models -(MLP, ResNet, FTTransformer, Mambular) across all three task variants -(Classifier, Regressor, LSS). A small synthetic dataset keeps CI fast. +Tests cover fit → predict → evaluate for all 15 stable models across all +three task variants (Classifier, Regressor, LSS). A small synthetic dataset +keeps CI fast. """ +import platform + import numpy as np import pandas as pd import pytest from sklearn.model_selection import train_test_split from deeptab.models import ( + ENODELSS, MLPLSS, + NDTFLSS, + NODELSS, + SAINTLSS, + AutoIntClassifier, + AutoIntLSS, + AutoIntRegressor, + ENODEClassifier, + ENODERegressor, FTTransformerClassifier, FTTransformerLSS, FTTransformerRegressor, + MambaTabClassifier, + MambaTabLSS, + MambaTabRegressor, + MambAttentionClassifier, + MambAttentionLSS, + MambAttentionRegressor, MambularClassifier, MambularLSS, MambularRegressor, MLPClassifier, MLPRegressor, + NDTFClassifier, + NDTFRegressor, + NODEClassifier, + NODERegressor, ResNetClassifier, ResNetLSS, ResNetRegressor, + SAINTClassifier, + SAINTRegressor, + TabMClassifier, + TabMLSS, + TabMRegressor, + TabRClassifier, + TabRLSS, + TabRRegressor, + TabTransformerClassifier, + TabTransformerLSS, + TabTransformerRegressor, + TabulaRNNClassifier, + TabulaRNNLSS, + TabulaRNNRegressor, +) + +_macos_arm64 = platform.system() == "Darwin" and platform.machine() == "arm64" +_skip_tabr = pytest.mark.skipif( + _macos_arm64, + reason="faiss-cpu from PyPI segfaults on macOS arm64; install via conda for TabR support", ) # --------------------------------------------------------------------------- @@ -56,6 +97,29 @@ def regression_data(): return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) +@pytest.fixture(scope="module") +def classification_data_with_cat(): + """Fixture with one categorical column — required by TabTransformer.""" + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y_cont = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + y = pd.qcut(y_cont, q=N_CLASSES, labels=False) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + df["cat_col"] = rng.choice(["A", "B", "C"], size=N_SAMPLES) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + +@pytest.fixture(scope="module") +def regression_data_with_cat(): + """Fixture with one categorical column — required by TabTransformer.""" + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + df["cat_col"] = rng.choice(["A", "B", "C"], size=N_SAMPLES) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + # --------------------------------------------------------------------------- # Classifier tests # --------------------------------------------------------------------------- @@ -65,6 +129,16 @@ def regression_data(): ResNetClassifier, FTTransformerClassifier, MambularClassifier, + TabMClassifier, + pytest.param(TabRClassifier, marks=_skip_tabr), + NODEClassifier, + NDTFClassifier, + SAINTClassifier, + AutoIntClassifier, + MambaTabClassifier, + MambAttentionClassifier, + TabulaRNNClassifier, + ENODEClassifier, ] @@ -115,6 +189,16 @@ def test_classifier_evaluate_returns_dict(cls, classification_data): ResNetRegressor, FTTransformerRegressor, MambularRegressor, + TabMRegressor, + pytest.param(TabRRegressor, marks=_skip_tabr), + NODERegressor, + NDTFRegressor, + SAINTRegressor, + AutoIntRegressor, + MambaTabRegressor, + MambAttentionRegressor, + TabulaRNNRegressor, + ENODERegressor, ] @@ -149,6 +233,16 @@ def test_regressor_evaluate_returns_dict(cls, regression_data): ResNetLSS, FTTransformerLSS, MambularLSS, + TabMLSS, + pytest.param(TabRLSS, marks=_skip_tabr), + NODELSS, + NDTFLSS, + SAINTLSS, + AutoIntLSS, + MambaTabLSS, + MambAttentionLSS, + TabulaRNNLSS, + ENODELSS, ] @@ -173,3 +267,57 @@ def test_lss_evaluate_returns_dict(cls, regression_data): metrics = model.evaluate(X_test, y_test) assert isinstance(metrics, dict), f"{cls.__name__}.evaluate should return a dict" assert len(metrics) > 0, f"{cls.__name__}.evaluate returned an empty dict" + + +# --------------------------------------------------------------------------- +# Config serialisation round-trip (Requirement 5) +# --------------------------------------------------------------------------- + +ALL_ESTIMATOR_CLASSES = ( + CLASSIFIERS + REGRESSORS + LSS_MODELS + [TabTransformerClassifier, TabTransformerRegressor, TabTransformerLSS] +) + + +@pytest.mark.parametrize("cls", ALL_ESTIMATOR_CLASSES) +def test_config_serialisation_roundtrip(cls): + """get_params() → construct new instance → config values survive.""" + model = cls() + params = model.get_params() + + # Constructing a second instance with the same params must not raise. + model2 = cls(**params) + + # All config kwargs must round-trip exactly. + for key, value in model.config_kwargs.items(): + assert getattr(model2.config, key, object()) == value, ( + f"{cls.__name__}: config.{key}={value!r} did not survive get_params round-trip" + ) + + +# --------------------------------------------------------------------------- +# TabTransformer — requires at least one categorical feature +# --------------------------------------------------------------------------- + +TAB_TRANSFORMER_MODELS = [ + (TabTransformerClassifier, "classification"), + (TabTransformerRegressor, "regression"), + (TabTransformerLSS, "lss"), +] + + +@pytest.mark.parametrize("cls,task", TAB_TRANSFORMER_MODELS) +def test_tabtransformer_fit_predict(cls, task, classification_data_with_cat, regression_data_with_cat): + if task == "classification": + X_train, X_test, y_train, _y_test = classification_data_with_cat + else: + X_train, X_test, y_train, _y_test = regression_data_with_cat + + model = cls() + if task == "lss": + model.fit(X_train, y_train, family="normal", **FIT_KWARGS) + else: + model.fit(X_train, y_train, **FIT_KWARGS) + + preds = model.predict(X_test) + assert preds.shape[0] == len(X_test), f"{cls.__name__}.predict returned unexpected shape" + assert np.isfinite(preds).all(), f"{cls.__name__}.predict returned non-finite values" diff --git a/tests/test_save_load.py b/tests/test_save_load.py new file mode 100644 index 0000000..6b64344 --- /dev/null +++ b/tests/test_save_load.py @@ -0,0 +1,139 @@ +""" +Round-trip save / load tests — Requirement 4. + +For each task type (Regressor, Classifier, LSS) we: + 1. Fit a small model. + 2. Record predictions on a held-out set. + 3. Save the model to a temporary file. + 4. Load it back into a fresh object. + 5. Assert the predictions are bit-for-bit identical. + +We use the lightest available model (MLP) to keep CI fast. +""" + +import tempfile + +import numpy as np +import pandas as pd +import pytest +from sklearn.model_selection import train_test_split + +from deeptab.models import MLPLSS, MLPClassifier, MLPRegressor + +# --------------------------------------------------------------------------- +# Shared dataset parameters +# --------------------------------------------------------------------------- + +N_SAMPLES = 200 +N_FEATURES = 6 +N_CLASSES = 3 +RANDOM_STATE = 7 +FIT_KWARGS = {"max_epochs": 2, "batch_size": 64} + + +@pytest.fixture(scope="module") +def regression_data(): + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + +@pytest.fixture(scope="module") +def classification_data(): + rng = np.random.default_rng(RANDOM_STATE) + X = rng.standard_normal((N_SAMPLES, N_FEATURES)) + y_cont = X @ rng.standard_normal(N_FEATURES) + rng.standard_normal(N_SAMPLES) + y = pd.qcut(y_cont, q=N_CLASSES, labels=False) + df = pd.DataFrame({f"f{i}": X[:, i] for i in range(N_FEATURES)}) + return train_test_split(df, y, test_size=0.2, random_state=RANDOM_STATE) + + +# --------------------------------------------------------------------------- +# Regressor round-trip +# --------------------------------------------------------------------------- + + +def test_regressor_save_load_predictions(regression_data): + X_train, X_test, y_train, _y_test = regression_data + model = MLPRegressor() + model.fit(X_train, y_train, **FIT_KWARGS) + + preds_before = model.predict(X_test) + + with tempfile.NamedTemporaryFile(suffix=".pt") as f: + model.save(f.name) + loaded = MLPRegressor.load(f.name) + + preds_after = loaded.predict(X_test) + + np.testing.assert_array_equal( + preds_before, + preds_after, + err_msg="MLPRegressor predictions changed after save/load round-trip", + ) + + +def test_regressor_save_raises_when_unfitted(): + model = MLPRegressor() + with pytest.raises(ValueError, match="fitted"): + with tempfile.NamedTemporaryFile(suffix=".pt") as f: + model.save(f.name) + + +# --------------------------------------------------------------------------- +# Classifier round-trip +# --------------------------------------------------------------------------- + + +def test_classifier_save_load_predictions(classification_data): + X_train, X_test, y_train, _y_test = classification_data + model = MLPClassifier() + model.fit(X_train, y_train, **FIT_KWARGS) + + preds_before = model.predict(X_test) + proba_before = model.predict_proba(X_test) + + with tempfile.NamedTemporaryFile(suffix=".pt") as f: + model.save(f.name) + loaded = MLPClassifier.load(f.name) + + preds_after = loaded.predict(X_test) + proba_after = loaded.predict_proba(X_test) + + np.testing.assert_array_equal( + preds_before, + preds_after, + err_msg="MLPClassifier.predict changed after save/load round-trip", + ) + np.testing.assert_array_equal( + proba_before, + proba_after, + err_msg="MLPClassifier.predict_proba changed after save/load round-trip", + ) + + +# --------------------------------------------------------------------------- +# LSS round-trip +# --------------------------------------------------------------------------- + + +def test_lss_save_load_predictions(regression_data): + X_train, X_test, y_train, _y_test = regression_data + model = MLPLSS() + model.fit(X_train, y_train, family="normal", **FIT_KWARGS) + + preds_before = model.predict(X_test) + + with tempfile.NamedTemporaryFile(suffix=".pt") as f: + model.save(f.name) + loaded = MLPLSS.load(f.name) + + preds_after = loaded.predict(X_test) + + np.testing.assert_array_equal( + preds_before, + preds_after, + err_msg="MLPLSS predictions changed after save/load round-trip", + ) From 97e0f1e7ea16ddfccf717e96847fd96348695e05 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 00:03:48 +0200 Subject: [PATCH 15/22] fix(lint): drop unused unpacked variables and dead code --- deeptab/arch_utils/layer_utils/batch_ensemble_layer.py | 4 ++-- deeptab/arch_utils/lstm_utils.py | 4 ++-- deeptab/arch_utils/mamba_utils/mamba_arch.py | 2 +- deeptab/arch_utils/transformer_utils.py | 5 ++--- deeptab/base_models/autoint.py | 3 --- deeptab/base_models/tabr.py | 10 ++++------ deeptab/base_models/utils/basemodel.py | 2 +- deeptab/base_models/utils/lightning_wrapper.py | 4 ++-- 8 files changed, 14 insertions(+), 20 deletions(-) diff --git a/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py b/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py index dd957d1..fb4973e 100644 --- a/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py +++ b/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py @@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor: """ # Check input shape and expand if necessary if x.dim() == 3: # Case: (B, L, D) - no ensembles - batch_size, seq_len, input_size = x.shape + batch_size, seq_len, _ = x.shape # Shape: (B, L, ensemble_size, D) x = x.unsqueeze(2).expand(-1, -1, self.ensemble_size, -1) elif x.dim() == 4 and x.size(2) == self.ensemble_size: # Case: (B, L, ensemble_size, D) @@ -451,7 +451,7 @@ def forward(self, query, key, value, mask=None): If the ensemble size `E` does not match `self.ensemble_size`. """ - N, S, E, D = query.size() + N, S, E, _ = query.size() if E != self.ensemble_size: raise ValueError("Ensemble size mismatch.") diff --git a/deeptab/arch_utils/lstm_utils.py b/deeptab/arch_utils/lstm_utils.py index 72514eb..b04a0a7 100644 --- a/deeptab/arch_utils/lstm_utils.py +++ b/deeptab/arch_utils/lstm_utils.py @@ -128,7 +128,7 @@ def forward(self, x): """ if x.ndim != 3: raise ValueError("Input tensor must have 3 dimensions (batch, sequence_length, input_size)") - B, N, D = x.shape + B, N, _ = x.shape device = x.device # Initialize states dynamically based on input shape @@ -293,7 +293,7 @@ def forward(self, x): torch.Tensor Output tensor of shape (batch, sequence_length, input_size). """ - B, N, D = x.shape + B, N, _ = x.shape device = x.device # Initialize states dynamically based on input shape diff --git a/deeptab/arch_utils/mamba_utils/mamba_arch.py b/deeptab/arch_utils/mamba_utils/mamba_arch.py index ab9c4b5..826ece5 100644 --- a/deeptab/arch_utils/mamba_utils/mamba_arch.py +++ b/deeptab/arch_utils/mamba_utils/mamba_arch.py @@ -437,7 +437,7 @@ def forward(self, x): if self.bidirectional: xz_bwd = self.in_proj_bwd(x) - x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1) + x_bwd, _ = xz_bwd.chunk(2, dim=-1) x_bwd = x_bwd.transpose(1, 2) x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L] diff --git a/deeptab/arch_utils/transformer_utils.py b/deeptab/arch_utils/transformer_utils.py index d25d2f0..3d5eb12 100644 --- a/deeptab/arch_utils/transformer_utils.py +++ b/deeptab/arch_utils/transformer_utils.py @@ -329,11 +329,10 @@ def forward(self, x, mask: torch.Tensor = None): # type: ignore The output tensor of shape (N, S, E, D). """ if x.dim() == 3: # Case: (B, L, D) - no ensembles - batch_size, seq_len, input_size = x.shape # Shape: (B, L, ensemble_size, D) x = x.unsqueeze(2).expand(-1, -1, self.ensemble_size, -1) elif x.dim() == 4 and x.size(2) == self.ensemble_size: # Case: (B, L, ensemble_size, D) - batch_size, seq_len, ensemble_size, _ = x.shape + _, _, ensemble_size, _ = x.shape if ensemble_size != self.ensemble_size: raise ValueError(f"Input shape {x.shape} is invalid. Expected shape: (B, S, ensemble_size, N)") else: @@ -425,7 +424,7 @@ def forward(self, x): x: Input embeddings of shape (N, J, D), where N = batch size, J = number of features, D = embedding dimension. """ - _, n, d = x.shape + _, n, _ = x.shape for attn1, ff1, attn2, ff2 in self.layers: # type: ignore # Column-wise attention diff --git a/deeptab/base_models/autoint.py b/deeptab/base_models/autoint.py index a8a6b97..fa25996 100644 --- a/deeptab/base_models/autoint.py +++ b/deeptab/base_models/autoint.py @@ -162,9 +162,6 @@ def forward(self, *data): # Apply normalization before attention if prenormalization is enabled x_residual = layer["norm0"](x_residual) # type: ignore[index] - # Retrieve key-value compression layers - key_compression, value_compression = self._get_kv_compressions(layer) - # Multihead Attention x_residual, _ = layer["attention"](x_residual, x_residual, x_residual) # type: ignore[index] diff --git a/deeptab/base_models/tabr.py b/deeptab/base_models/tabr.py index 4f327ca..7d807e8 100644 --- a/deeptab/base_models/tabr.py +++ b/deeptab/base_models/tabr.py @@ -324,7 +324,7 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): candidate_x, candidate_k = self._encode(candidate_x) x, k = self._encode(x) # encoded x and k - batch_size, d_main = k.shape + _, d_main = k.shape device = k.device context_size = self.context_size @@ -338,9 +338,8 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): # Updating the index is much faster than creating a new one. self.search_index.reset() self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code] - distances: Tensor context_idx: Tensor - distances, context_idx = self.search_index.search( # type: ignore[code] + _, context_idx = self.search_index.search( # type: ignore[code] k.to(torch.float32), context_size ) @@ -398,7 +397,7 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): candidate_x, candidate_k = self._encode(candidate_x) x, k = self._encode(x) # encoded x and k - batch_size, d_main = k.shape + _, d_main = k.shape device = k.device context_size = self.context_size @@ -412,9 +411,8 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): # Updating the index is much faster than creating a new one. self.search_index.reset() self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code] - distances: Tensor context_idx: Tensor - distances, context_idx = self.search_index.search( # type: ignore[code] + _, context_idx = self.search_index.search( # type: ignore[code] k.to(torch.float32), context_size ) diff --git a/deeptab/base_models/utils/basemodel.py b/deeptab/base_models/utils/basemodel.py index 6837b73..d6d7e37 100644 --- a/deeptab/base_models/utils/basemodel.py +++ b/deeptab/base_models/utils/basemodel.py @@ -193,7 +193,7 @@ def pool_sequence(self, out): return out[:, 0, :] elif self.hparams.pooling_method == "learned_flatten": # Flatten sequence and apply a learned linear layer - batch_size, seq_len, hidden_size = out.shape + batch_size, _, _ = out.shape # Shape: (batch_size, seq_len * hidden_size) out = out.reshape(batch_size, -1) # Shape: (batch_size, hidden_size) diff --git a/deeptab/base_models/utils/lightning_wrapper.py b/deeptab/base_models/utils/lightning_wrapper.py index ae6125d..6511e82 100644 --- a/deeptab/base_models/utils/lightning_wrapper.py +++ b/deeptab/base_models/utils/lightning_wrapper.py @@ -176,7 +176,7 @@ def compute_loss(self, predictions, y_true): if getattr(self.estimator, "returns_ensemble", False): # Ensemble case if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3: # Classification case with ensemble: predictions (N, E, k), y_true (N,) - N, E, k = predictions.shape + _, E, _ = predictions.shape loss = 0.0 for ensemble_member in range(E): loss += self.loss_fct( @@ -595,7 +595,7 @@ def contrastive_loss(self, embeddings, knn_indices, temperature=0.1): Tensor Contrastive loss value. """ - N, S, D = embeddings.shape # Batch size, sequence length, embedding dim + _, S, D = embeddings.shape # Batch size, sequence length, embedding dim k_neighbors = knn_indices.shape[1] # Number of neighbors # Normalize embeddings From f666805ff12b32c9eef339dde93d9ce2bd04d05f Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 00:04:26 +0200 Subject: [PATCH 16/22] fix(types): guard task_model None, use getattr for get_params, annotate test kwargs --- deeptab/models/utils/sklearn_base_lss.py | 10 +++++++--- deeptab/models/utils/sklearn_parent.py | 16 +++++++++------- tests/test_save_load.py | 3 ++- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/deeptab/models/utils/sklearn_base_lss.py b/deeptab/models/utils/sklearn_base_lss.py index 7baa6aa..59b08f9 100644 --- a/deeptab/models/utils/sklearn_base_lss.py +++ b/deeptab/models/utils/sklearn_base_lss.py @@ -119,9 +119,11 @@ def get_params(self, deep=True): params = {} params.update(self.config_kwargs) - if deep and hasattr(self.preprocessor, "get_params"): - preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} - params.update(preprocessor_params) + if deep: + get_params_fn = getattr(self.preprocessor, "get_params", None) + if get_params_fn is not None: + preprocessor_params = {"prepro__" + key: value for key, value in get_params_fn().items()} + params.update(preprocessor_params) return params @@ -656,6 +658,8 @@ def save(self, path: str) -> None: """ if not getattr(self, "is_fitted_", False): raise ValueError("Model must be fitted before saving.") + if self.task_model is None: + raise RuntimeError("task_model is unexpectedly None after fitting.") bundle = { "_class": type(self), "config": self.config, diff --git a/deeptab/models/utils/sklearn_parent.py b/deeptab/models/utils/sklearn_parent.py index 1c15b8e..f738cac 100644 --- a/deeptab/models/utils/sklearn_parent.py +++ b/deeptab/models/utils/sklearn_parent.py @@ -63,13 +63,13 @@ def get_params(self, deep=True): params = {} params.update(self.config_kwargs) params.update(self.preprocessor_kwargs) - if deep and hasattr(self.preprocessor, "get_params"): - preprocessor_params = { - key: value - for key, value in self.preprocessor.get_params().items() - if key in self.preprocessor_arg_names - } - params.update(preprocessor_params) + if deep: + get_params_fn = getattr(self.preprocessor, "get_params", None) + if get_params_fn is not None: + preprocessor_params = { + key: value for key, value in get_params_fn().items() if key in self.preprocessor_arg_names + } + params.update(preprocessor_params) return params def set_params(self, **parameters): @@ -502,6 +502,8 @@ def save(self, path: str) -> None: """ if not getattr(self, "is_fitted_", False): raise ValueError("Model must be fitted before saving.") + if self.task_model is None: + raise RuntimeError("task_model is unexpectedly None after fitting.") bundle = { "_class": type(self), "config": self.config, diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 6b64344..86a0f2a 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,6 +12,7 @@ """ import tempfile +from typing import Any import numpy as np import pandas as pd @@ -28,7 +29,7 @@ N_FEATURES = 6 N_CLASSES = 3 RANDOM_STATE = 7 -FIT_KWARGS = {"max_epochs": 2, "batch_size": 64} +FIT_KWARGS: dict[str, Any] = {"max_epochs": 2, "batch_size": 64} @pytest.fixture(scope="module") From 0b29e4321d416680de4bdb3934d33906dd358795 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 00:45:23 +0200 Subject: [PATCH 17/22] fix(tabr): use regression label encoder for LSS by forwarding lss flag TabR used num_classes > 1 to select nn.Embedding for label encoding, mistaking LSS distribution param count (e.g. 2 for Normal) for a class count. This caused IndexError when continuous candidate_y targets were cast to .long() indices. Fix: add lss: bool = False to TabR.__init__, use nn.Linear label encoder when lss=True regardless of num_classes, and guard the three forward-pass classification branches with 'and not self.hparams.lss'. TaskModel already receives lss=True from SklearnBaseLSS; now forwards it to the model class so TabR can observe it. --- deeptab/base_models/tabr.py | 13 +++++++------ deeptab/base_models/utils/lightning_wrapper.py | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/deeptab/base_models/tabr.py b/deeptab/base_models/tabr.py index 7d807e8..06a79e8 100644 --- a/deeptab/base_models/tabr.py +++ b/deeptab/base_models/tabr.py @@ -21,10 +21,11 @@ def __init__( self, feature_information: tuple, num_classes=1, + lss: bool = False, config: DefaultTabRConfig = DefaultTabRConfig(), # noqa: B008 **kwargs, ): - super().__init__(config=config, **kwargs) + super().__init__(config=config, lss=lss, **kwargs) self.save_hyperparameters(ignore=["feature_information"]) # lazy import @@ -94,7 +95,7 @@ def make_block(prenorm: bool) -> nn.Sequential: delu = TabR.delu self.label_encoder = ( nn.Linear(1, d_main) - if num_classes == 1 + if num_classes == 1 or lss else nn.Sequential( nn.Embedding(num_classes, d_main), # gives depreciation warning @@ -278,9 +279,9 @@ def train_with_candidates(self, *data, targets, candidate_x, candidate_y): probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - if self.hparams.num_classes > 1: # for classification + if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) - else: # for regression + else: # for regression or LSS context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] @@ -352,7 +353,7 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - if self.hparams.num_classes > 1: # for classification + if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) @@ -425,7 +426,7 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - if self.hparams.num_classes > 1: # for classification + if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) diff --git a/deeptab/base_models/utils/lightning_wrapper.py b/deeptab/base_models/utils/lightning_wrapper.py index 6511e82..7e58239 100644 --- a/deeptab/base_models/utils/lightning_wrapper.py +++ b/deeptab/base_models/utils/lightning_wrapper.py @@ -93,6 +93,7 @@ def __init__( config=config, feature_information=feature_information, num_classes=output_dim, + lss=lss, **kwargs, ) From 971a2fb1764dcd2c9b0566d50cdd12e1df7470f2 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 01:02:33 +0200 Subject: [PATCH 18/22] fix(tabr): cast candidate_y to float in regression/LSS label encoder path --- deeptab/base_models/tabr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeptab/base_models/tabr.py b/deeptab/base_models/tabr.py index 06a79e8..187ff06 100644 --- a/deeptab/base_models/tabr.py +++ b/deeptab/base_models/tabr.py @@ -282,7 +282,7 @@ def train_with_candidates(self, *data, targets, candidate_x, candidate_y): if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression or LSS - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] @@ -356,7 +356,7 @@ def validate_with_candidates(self, *data, candidate_x, candidate_y): if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] @@ -429,7 +429,7 @@ def predict_with_candidates(self, *data, candidate_x, candidate_y): if self.hparams.num_classes > 1 and not self.hparams.lss: # for classification context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long()) else: # for regression - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) if len(context_y_emb.shape) == 4: context_y_emb = context_y_emb[:, :, 0, :] From 9180fb257c8dc1e33c7737641048c9e773e13e21 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 01:04:15 +0200 Subject: [PATCH 19/22] fix: suppress family hparam warning --- deeptab/base_models/utils/lightning_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptab/base_models/utils/lightning_wrapper.py b/deeptab/base_models/utils/lightning_wrapper.py index 7e58239..9a9c390 100644 --- a/deeptab/base_models/utils/lightning_wrapper.py +++ b/deeptab/base_models/utils/lightning_wrapper.py @@ -77,7 +77,7 @@ def __init__( else: self.loss_fct = nn.MSELoss() - self.save_hyperparameters(ignore=["model_class", "loss_fn"]) + self.save_hyperparameters(ignore=["model_class", "loss_fn", "family"]) self.lr = self.hparams.get("lr", config.lr) self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) From a0f70f40888068daade29dbe305f57ac75d12f52 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 01:13:27 +0200 Subject: [PATCH 20/22] chore: suppress tool warnings --- pyproject.toml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a98222e..a611280 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,18 @@ package = "https://pypi.org/project/deeptab/" # test configuration [tool.pytest.ini_options] pythonpath = ["."] +filterwarnings = [ + # Lightning trainer noise (dataloader workers, log interval, checkpoint dir, tensorboard) + "ignore::UserWarning:lightning", + # PyTorch transformer / RNN config warnings + "ignore::UserWarning:torch", + # properscoring invalid escape sequences + "ignore::DeprecationWarning:properscoring", + # faiss SwigPy builtin type warnings (emitted during import) + "ignore::DeprecationWarning:importlib", + # delu.nn.Lambda custom-function deprecation + "ignore::DeprecationWarning:delu", +] # code quality tools [tool.pyright] From f1fc937a19da14ce7c20ac3f8403485ea124b514 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 01:23:41 +0200 Subject: [PATCH 21/22] fix: test case fixed for windows --- tests/test_save_load.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 86a0f2a..07db1df 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -11,6 +11,7 @@ We use the lightest available model (MLP) to keep CI fast. """ +import os import tempfile from typing import Any @@ -63,9 +64,13 @@ def test_regressor_save_load_predictions(regression_data): preds_before = model.predict(X_test) - with tempfile.NamedTemporaryFile(suffix=".pt") as f: - model.save(f.name) - loaded = MLPRegressor.load(f.name) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tmp_path = f.name + try: + model.save(tmp_path) + loaded = MLPRegressor.load(tmp_path) + finally: + os.unlink(tmp_path) preds_after = loaded.predict(X_test) @@ -96,9 +101,13 @@ def test_classifier_save_load_predictions(classification_data): preds_before = model.predict(X_test) proba_before = model.predict_proba(X_test) - with tempfile.NamedTemporaryFile(suffix=".pt") as f: - model.save(f.name) - loaded = MLPClassifier.load(f.name) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tmp_path = f.name + try: + model.save(tmp_path) + loaded = MLPClassifier.load(tmp_path) + finally: + os.unlink(tmp_path) preds_after = loaded.predict(X_test) proba_after = loaded.predict_proba(X_test) @@ -127,9 +136,13 @@ def test_lss_save_load_predictions(regression_data): preds_before = model.predict(X_test) - with tempfile.NamedTemporaryFile(suffix=".pt") as f: - model.save(f.name) - loaded = MLPLSS.load(f.name) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tmp_path = f.name + try: + model.save(tmp_path) + loaded = MLPLSS.load(tmp_path) + finally: + os.unlink(tmp_path) preds_after = loaded.predict(X_test) From b081d4b8cda3ee4a52db638a8df881f6fe82d34a Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 15 May 2026 10:30:57 +0200 Subject: [PATCH 22/22] fix: duplicate entry in docs for model classes --- docs/api/models/autoint.rst | 3 +++ docs/api/models/enode.rst | 3 +++ docs/api/models/fttransformer.rst | 3 +++ docs/api/models/mambatab.rst | 3 +++ docs/api/models/mambattention.rst | 3 +++ docs/api/models/mambular.rst | 3 +++ docs/api/models/mlp.rst | 3 +++ docs/api/models/ndtf.rst | 3 +++ docs/api/models/node.rst | 3 +++ docs/api/models/resnet.rst | 3 +++ docs/api/models/saint.rst | 3 +++ docs/api/models/tabm.rst | 3 +++ docs/api/models/tabr.rst | 3 +++ docs/api/models/tabtransformer.rst | 3 +++ docs/api/models/tabularrnn.rst | 3 +++ 15 files changed, 45 insertions(+) diff --git a/docs/api/models/autoint.rst b/docs/api/models/autoint.rst index fa8a4aa..07a5e07 100644 --- a/docs/api/models/autoint.rst +++ b/docs/api/models/autoint.rst @@ -33,11 +33,14 @@ API Reference .. autoclass:: AutoIntRegressor :members: :undoc-members: + :noindex: .. autoclass:: AutoIntClassifier :members: :undoc-members: + :noindex: .. autoclass:: AutoIntLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/enode.rst b/docs/api/models/enode.rst index a1270ba..9d76d46 100644 --- a/docs/api/models/enode.rst +++ b/docs/api/models/enode.rst @@ -30,11 +30,14 @@ API Reference .. autoclass:: ENODERegressor :members: :undoc-members: + :noindex: .. autoclass:: ENODEClassifier :members: :undoc-members: + :noindex: .. autoclass:: ENODELSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/fttransformer.rst b/docs/api/models/fttransformer.rst index f37fc83..460164a 100644 --- a/docs/api/models/fttransformer.rst +++ b/docs/api/models/fttransformer.rst @@ -29,11 +29,14 @@ API Reference .. autoclass:: FTTransformerRegressor :members: :undoc-members: + :noindex: .. autoclass:: FTTransformerClassifier :members: :undoc-members: + :noindex: .. autoclass:: FTTransformerLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/mambatab.rst b/docs/api/models/mambatab.rst index f2be816..9eedf78 100644 --- a/docs/api/models/mambatab.rst +++ b/docs/api/models/mambatab.rst @@ -28,11 +28,14 @@ API Reference .. autoclass:: MambaTabRegressor :members: :undoc-members: + :noindex: .. autoclass:: MambaTabClassifier :members: :undoc-members: + :noindex: .. autoclass:: MambaTabLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/mambattention.rst b/docs/api/models/mambattention.rst index d14fc88..f648f87 100644 --- a/docs/api/models/mambattention.rst +++ b/docs/api/models/mambattention.rst @@ -29,11 +29,14 @@ API Reference .. autoclass:: MambAttentionRegressor :members: :undoc-members: + :noindex: .. autoclass:: MambAttentionClassifier :members: :undoc-members: + :noindex: .. autoclass:: MambAttentionLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/mambular.rst b/docs/api/models/mambular.rst index d0fa7b3..5dd3998 100644 --- a/docs/api/models/mambular.rst +++ b/docs/api/models/mambular.rst @@ -30,11 +30,14 @@ API Reference .. autoclass:: MambularRegressor :members: :undoc-members: + :noindex: .. autoclass:: MambularClassifier :members: :undoc-members: + :noindex: .. autoclass:: MambularLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/mlp.rst b/docs/api/models/mlp.rst index 5388b52..bfa408a 100644 --- a/docs/api/models/mlp.rst +++ b/docs/api/models/mlp.rst @@ -29,11 +29,14 @@ API Reference .. autoclass:: MLPRegressor :members: :undoc-members: + :noindex: .. autoclass:: MLPClassifier :members: :undoc-members: + :noindex: .. autoclass:: MLPLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/ndtf.rst b/docs/api/models/ndtf.rst index 7d45dbc..b8e8a5b 100644 --- a/docs/api/models/ndtf.rst +++ b/docs/api/models/ndtf.rst @@ -30,11 +30,14 @@ API Reference .. autoclass:: NDTFRegressor :members: :undoc-members: + :noindex: .. autoclass:: NDTFClassifier :members: :undoc-members: + :noindex: .. autoclass:: NDTFLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/node.rst b/docs/api/models/node.rst index 076c22f..d011756 100644 --- a/docs/api/models/node.rst +++ b/docs/api/models/node.rst @@ -30,11 +30,14 @@ API Reference .. autoclass:: NODERegressor :members: :undoc-members: + :noindex: .. autoclass:: NODEClassifier :members: :undoc-members: + :noindex: .. autoclass:: NODELSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/resnet.rst b/docs/api/models/resnet.rst index 37b3ad9..67f8398 100644 --- a/docs/api/models/resnet.rst +++ b/docs/api/models/resnet.rst @@ -28,11 +28,14 @@ API Reference .. autoclass:: ResNetRegressor :members: :undoc-members: + :noindex: .. autoclass:: ResNetClassifier :members: :undoc-members: + :noindex: .. autoclass:: ResNetLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/saint.rst b/docs/api/models/saint.rst index 0863377..652e7df 100644 --- a/docs/api/models/saint.rst +++ b/docs/api/models/saint.rst @@ -30,11 +30,14 @@ API Reference .. autoclass:: SAINTRegressor :members: :undoc-members: + :noindex: .. autoclass:: SAINTClassifier :members: :undoc-members: + :noindex: .. autoclass:: SAINTLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/tabm.rst b/docs/api/models/tabm.rst index db506d3..d03a7fd 100644 --- a/docs/api/models/tabm.rst +++ b/docs/api/models/tabm.rst @@ -28,11 +28,14 @@ API Reference .. autoclass:: TabMRegressor :members: :undoc-members: + :noindex: .. autoclass:: TabMClassifier :members: :undoc-members: + :noindex: .. autoclass:: TabMLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/tabr.rst b/docs/api/models/tabr.rst index 30c4eff..779802e 100644 --- a/docs/api/models/tabr.rst +++ b/docs/api/models/tabr.rst @@ -30,11 +30,14 @@ API Reference .. autoclass:: TabRRegressor :members: :undoc-members: + :noindex: .. autoclass:: TabRClassifier :members: :undoc-members: + :noindex: .. autoclass:: TabRLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/tabtransformer.rst b/docs/api/models/tabtransformer.rst index 3d4134a..6bcfdf8 100644 --- a/docs/api/models/tabtransformer.rst +++ b/docs/api/models/tabtransformer.rst @@ -29,11 +29,14 @@ API Reference .. autoclass:: TabTransformerRegressor :members: :undoc-members: + :noindex: .. autoclass:: TabTransformerClassifier :members: :undoc-members: + :noindex: .. autoclass:: TabTransformerLSS :members: :undoc-members: + :noindex: diff --git a/docs/api/models/tabularrnn.rst b/docs/api/models/tabularrnn.rst index 3e77286..f9cc2b5 100644 --- a/docs/api/models/tabularrnn.rst +++ b/docs/api/models/tabularrnn.rst @@ -31,11 +31,14 @@ API Reference .. autoclass:: TabulaRNNRegressor :members: :undoc-members: + :noindex: .. autoclass:: TabulaRNNClassifier :members: :undoc-members: + :noindex: .. autoclass:: TabulaRNNLSS :members: :undoc-members: + :noindex: