Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ _build/
*.prof
.venv/
.python-version
uv.lock
uv.lock

# Claude Code artifacts
CLAUDE.md
.claude/
2 changes: 1 addition & 1 deletion causalml/inference/meta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from .xlearner import BaseXLearner, BaseXRegressor, BaseXClassifier
from .rlearner import BaseRLearner, BaseRRegressor, BaseRClassifier, XGBRRegressor
from .tmle import TMLELearner
from .drlearner import BaseDRLearner, BaseDRRegressor, XGBDRRegressor
from .drlearner import BaseDRLearner, BaseDRRegressor, BaseDRClassifier, XGBDRRegressor
98 changes: 97 additions & 1 deletion causalml/inference/meta/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
check_p_conditions,
convert_pd_to_np,
)
from causalml.metrics import regression_metrics
from causalml.metrics import regression_metrics, classification_metrics
from causalml.propensity import compute_propensity_score


Expand Down Expand Up @@ -487,6 +487,102 @@ def __init__(
)


class BaseDRClassifier(BaseDRLearner):
"""
A parent class for DR-learner classifier classes.
"""

def __init__(
self,
learner=None,
control_outcome_learner=None,
treatment_outcome_learner=None,
treatment_effect_learner=None,
ate_alpha=0.05,
control_name=0,
):
"""Initialize a DR-learner classifier.

Args:
learner (optional): a model to estimate outcomes and treatment effects in both the control and treatment
groups. Should have a predict_proba() method for outcome models.
control_outcome_learner (optional): a model to estimate outcomes in the control group.
Should have a predict_proba() method.
treatment_outcome_learner (optional): a model to estimate outcomes in the treatment group.
Should have a predict_proba() method.
treatment_effect_learner (optional): a model to estimate treatment effects in the treatment group.
Should be a regressor.
ate_alpha (float, optional): the confidence level alpha of the ATE estimate
control_name (str or int, optional): name of control group
"""
super().__init__(
learner=learner,
control_outcome_learner=control_outcome_learner,
treatment_outcome_learner=treatment_outcome_learner,
treatment_effect_learner=treatment_effect_learner,
ate_alpha=ate_alpha,
control_name=control_name,
)

def predict(
self, X, treatment=None, y=None, p=None, return_components=False, verbose=True
):
"""Predict treatment effects.

Args:
X (np.matrix or np.array or pd.Dataframe): a feature matrix
treatment (np.array or pd.Series, optional): a treatment vector. Used for computing
classification metrics when y is also provided.
y (np.array or pd.Series, optional): an outcome vector. Used for computing
classification metrics when treatment is also provided.
p (np.ndarray or pd.Series or dict, optional): an array of propensity scores of float (0,1) in the
single-treatment case; or, a dictionary of treatment groups that map to propensity vectors of
float (0,1). Currently not used in prediction but kept for API consistency.
return_components (bool, optional): whether to return outcome probabilities for treatment and control
groups separately. Defaults to False.
verbose (bool, optional): whether to output progress logs. Defaults to True.
Returns:
(numpy.ndarray): Predictions of treatment effects.
If return_components is True, also returns:
- dict: Predicted probabilities for the control group (yhat_cs).
- dict: Predicted probabilities for the treatment group (yhat_ts).
"""
X, treatment, y = convert_pd_to_np(X, treatment, y)

te = np.zeros((X.shape[0], self.t_groups.shape[0]))
yhat_cs = {}
yhat_ts = {}

for i, group in enumerate(self.t_groups):
models_tau = self.models_tau[group]
_te = np.r_[[model.predict(X) for model in models_tau]].mean(axis=0)
te[:, i] = np.ravel(_te)
yhat_cs[group] = np.r_[
[model.predict_proba(X)[:, 1] for model in self.models_mu_c]
].mean(axis=0)
yhat_ts[group] = np.r_[
[model.predict_proba(X)[:, 1] for model in self.models_mu_t[group]]
].mean(axis=0)

if (y is not None) and (treatment is not None) and verbose:
mask = (treatment == group) | (treatment == self.control_name)
treatment_filt = treatment[mask]
y_filt = y[mask]
w = (treatment_filt == group).astype(int)

yhat = np.zeros_like(y_filt, dtype=float)
yhat[w == 0] = yhat_cs[group][mask][w == 0]
yhat[w == 1] = yhat_ts[group][mask][w == 1]

logger.info("Error metrics for group {}".format(group))
classification_metrics(y_filt, yhat, w)

if not return_components:
return te
else:
return te, yhat_cs, yhat_ts


class XGBDRRegressor(BaseDRRegressor):
def __init__(self, ate_alpha=0.05, control_name=0, *args, **kwargs):
"""Initialize a DR-learner with two XGBoost models."""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"lightgbm",
"packaging",
"graphviz",
"black>=25.1.0",
]

[project.optional-dependencies]
Expand Down
48 changes: 48 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from causalml.inference.meta import TMLELearner
from causalml.inference.meta import BaseDRLearner
from causalml.inference.meta import BaseDRRegressor
from causalml.inference.meta import BaseDRClassifier
from causalml.metrics import ape, auuc_score

from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION
Expand Down Expand Up @@ -1039,3 +1041,49 @@ def test_BaseDRLearner(generate_regression_data):
normalize=True,
)
assert auuc["cate_p"] > 0.5


def test_BaseDRClassifier(generate_classification_data):
np.random.seed(RANDOM_SEED)

df, X_names = generate_classification_data()

df["treatment_group_key"] = np.where(
df["treatment_group_key"] == CONTROL_NAME, 0, 1
)

# Extract features and outcome
y = df[CONVERSION].values
X = df[X_names].values
treatment = df["treatment_group_key"].values

learner = BaseDRClassifier(
learner=LogisticRegression(), treatment_effect_learner=LinearRegression()
)

# Test fit and predict
te = learner.fit_predict(X=X, treatment=treatment, y=y)

# Check that treatment effects are returned
assert te.shape[0] == X.shape[0]
assert te.shape[1] == len(np.unique(treatment[treatment != 0]))

# Test with return_components
te, yhat_cs, yhat_ts = learner.fit_predict(
X=X, treatment=treatment, y=y, return_components=True
)

# Check that components are returned as probabilities
for group in learner.t_groups:
assert np.all((yhat_cs[group] >= 0) & (yhat_cs[group] <= 1))
assert np.all((yhat_ts[group] >= 0) & (yhat_ts[group] <= 1))

# Test separate outcome and effect learners
learner_separate = BaseDRClassifier(
control_outcome_learner=LogisticRegression(),
treatment_outcome_learner=LogisticRegression(),
treatment_effect_learner=LinearRegression(),
)

te_separate = learner_separate.fit_predict(X=X, treatment=treatment, y=y)
assert te_separate.shape == te.shape