
Commit 2d112cd

jeongyoonlee and claude committed
Add BaseDRClassifier for binary classification with probabilities
Implements BaseDRClassifier class to address issue #819 by providing a DR-learner that outputs probabilities for binary classification problems, similar to other S/T/X learner classifiers.

Key features:
- Uses predict_proba() for outcome models to return probabilities
- Maintains doubly robust estimation framework
- Supports both single learner and separate outcome/effect learners
- Includes comprehensive tests with classification data
- Follows existing classifier implementation patterns

Fixes #819

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 5f5c4fb commit 2d112cd
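
For orientation, a minimal usage sketch of the new class, mirroring the test added in this commit. The synthetic data via sklearn's make_classification and the seed are illustrative assumptions, not part of the change:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LinearRegression, LogisticRegression

from causalml.inference.meta import BaseDRClassifier

# Illustrative synthetic binary-outcome data (not part of this commit)
np.random.seed(42)
X, y = make_classification(n_samples=1000, n_features=8, random_state=42)
treatment = np.random.binomial(1, 0.5, size=X.shape[0])  # 0 = control, 1 = treated

learner = BaseDRClassifier(
    learner=LogisticRegression(),                 # outcome models use predict_proba()
    treatment_effect_learner=LinearRegression(),  # effect model is a plain regressor
)

# CATE estimates: one column per non-control treatment group
te = learner.fit_predict(X=X, treatment=treatment, y=y)

# Optionally return the outcome-model probabilities alongside the effects
te, yhat_cs, yhat_ts = learner.fit_predict(
    X=X, treatment=treatment, y=y, return_components=True
)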

3 files changed, +137 −3 lines

causalml/inference/meta/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@
 from .xlearner import BaseXLearner, BaseXRegressor, BaseXClassifier
 from .rlearner import BaseRLearner, BaseRRegressor, BaseRClassifier, XGBRRegressor
 from .tmle import TMLELearner
-from .drlearner import BaseDRLearner, BaseDRRegressor, XGBDRRegressor
+from .drlearner import BaseDRLearner, BaseDRRegressor, BaseDRClassifier, XGBDRRegressor

causalml/inference/meta/drlearner.py

Lines changed: 87 additions & 1 deletion
@@ -13,7 +13,7 @@
     check_p_conditions,
     convert_pd_to_np,
 )
-from causalml.metrics import regression_metrics
+from causalml.metrics import regression_metrics, classification_metrics
 from causalml.propensity import compute_propensity_score
 
 
@@ -487,6 +487,92 @@ def __init__(
         )
 
 
+class BaseDRClassifier(BaseDRLearner):
+    """
+    A parent class for DR-learner classifier classes.
+    """
+
+    def __init__(
+        self,
+        learner=None,
+        control_outcome_learner=None,
+        treatment_outcome_learner=None,
+        treatment_effect_learner=None,
+        ate_alpha=0.05,
+        control_name=0,
+    ):
+        """Initialize a DR-learner classifier.
+
+        Args:
+            learner (optional): a model to estimate outcomes and treatment effects in both the control and treatment
+                groups. Should have a predict_proba() method for outcome models.
+            control_outcome_learner (optional): a model to estimate outcomes in the control group.
+                Should have a predict_proba() method.
+            treatment_outcome_learner (optional): a model to estimate outcomes in the treatment group.
+                Should have a predict_proba() method.
+            treatment_effect_learner (optional): a model to estimate treatment effects in the treatment group.
+                Should be a regressor.
+            ate_alpha (float, optional): the confidence level alpha of the ATE estimate
+            control_name (str or int, optional): name of control group
+        """
+        super().__init__(
+            learner=learner,
+            control_outcome_learner=control_outcome_learner,
+            treatment_outcome_learner=treatment_outcome_learner,
+            treatment_effect_learner=treatment_effect_learner,
+            ate_alpha=ate_alpha,
+            control_name=control_name,
+        )
+
+    def predict(
+        self, X, treatment=None, y=None, p=None, return_components=False, verbose=True
+    ):
+        """Predict treatment effects.
+
+        Args:
+            X (np.matrix or np.array or pd.Dataframe): a feature matrix
+            treatment (np.array or pd.Series, optional): a treatment vector
+            y (np.array or pd.Series, optional): an outcome vector
+            verbose (bool, optional): whether to output progress logs
+        Returns:
+            (numpy.ndarray): Predictions of treatment effects.
+        """
+        X, treatment, y = convert_pd_to_np(X, treatment, y)
+
+        te = np.zeros((X.shape[0], self.t_groups.shape[0]))
+        yhat_cs = {}
+        yhat_ts = {}
+
+        for i, group in enumerate(self.t_groups):
+            models_tau = self.models_tau[group]
+            _te = np.r_[[model.predict(X) for model in models_tau]].mean(axis=0)
+            te[:, i] = np.ravel(_te)
+            yhat_cs[group] = np.r_[
+                [model.predict_proba(X)[:, 1] for model in self.models_mu_c]
+            ].mean(axis=0)
+            yhat_ts[group] = np.r_[
+                [model.predict_proba(X)[:, 1] for model in self.models_mu_t[group]]
+            ].mean(axis=0)
+
+            if (y is not None) and (treatment is not None) and verbose:
+                mask = (treatment == group) | (treatment == self.control_name)
+                treatment_filt = treatment[mask]
+                y_filt = y[mask]
+                w = (treatment_filt == group).astype(int)
+
+                yhat = np.zeros_like(y_filt, dtype=float)
+                yhat[w == 0] = yhat_cs[group][mask][w == 0]
+                yhat[w == 1] = yhat_ts[group][mask][w == 1]
+
+                logger.info("Error metrics for group {}".format(group))
+                classification_metrics(y_filt, yhat, w)
+
+        if not return_components:
+            return te
+        else:
+            return te, yhat_cs, yhat_ts
+
+
 class XGBDRRegressor(BaseDRRegressor):
     def __init__(self, ate_alpha=0.05, control_name=0, *args, **kwargs):
         """Initialize a DR-learner with two XGBoost models."""

tests/test_meta_learners.py

Lines changed: 49 additions & 1 deletion
@@ -30,7 +30,7 @@
     XGBRRegressor,
 )
 from causalml.inference.meta import TMLELearner
-from causalml.inference.meta import BaseDRLearner
+from causalml.inference.meta import BaseDRLearner, BaseDRRegressor, BaseDRClassifier
 from causalml.metrics import ape, auuc_score
 
 from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION
@@ -1039,3 +1039,51 @@ def test_BaseDRLearner(generate_regression_data):
         normalize=True,
     )
     assert auuc["cate_p"] > 0.5
+
+
+
+def test_BaseDRClassifier(generate_classification_data):
+    np.random.seed(RANDOM_SEED)
+
+    df, X_names = generate_classification_data()
+
+    df["treatment_group_key"] = np.where(
+        df["treatment_group_key"] == CONTROL_NAME, 0, 1
+    )
+
+    # Extract features and outcome
+    y = df[CONVERSION].values
+    X = df[X_names].values
+    treatment = df["treatment_group_key"].values
+
+    learner = BaseDRClassifier(
+        learner=LogisticRegression(),
+        treatment_effect_learner=LinearRegression()
+    )
+
+    # Test fit and predict
+    te = learner.fit_predict(X=X, treatment=treatment, y=y)
+
+    # Check that treatment effects are returned
+    assert te.shape[0] == X.shape[0]
+    assert te.shape[1] == len(np.unique(treatment[treatment != 0]))
+
+    # Test with return_components
+    te, yhat_cs, yhat_ts = learner.fit_predict(
+        X=X, treatment=treatment, y=y, return_components=True
+    )
+
+    # Check that components are returned as probabilities
+    for group in learner.t_groups:
+        assert np.all((yhat_cs[group] >= 0) & (yhat_cs[group] <= 1))
+        assert np.all((yhat_ts[group] >= 0) & (yhat_ts[group] <= 1))
+
+    # Test separate outcome and effect learners
+    learner_separate = BaseDRClassifier(
+        control_outcome_learner=LogisticRegression(),
+        treatment_outcome_learner=LogisticRegression(),
+        treatment_effect_learner=LinearRegression()
+    )
+
+    te_separate = learner_separate.fit_predict(X=X, treatment=treatment, y=y)
+    assert te_separate.shape == te.shape
