Skip to content

Commit 2ce28c1

Browse files
authored
Update sensitivity tests with more meta-learners (#759)
* update sensitivity tests with more meta-learners * fix lint errors * reformat with black * fix the type hint error for | in Python3.7 by using typing.Union
1 parent 12747ee commit 2ce28c1

File tree

23 files changed

+170
-136
lines changed

23 files changed

+170
-136
lines changed

causalml/dataset/classification.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import numpy as np
33
import pandas as pd
44
from sklearn.datasets import make_classification
5+
from scipy.interpolate import UnivariateSpline
56
from scipy.optimize import fsolve
6-
from scipy.special import expit
7-
from scipy.special import logit
7+
from scipy.special import expit, logit
88

99

1010
# ------ Define a list of functions for feature transformation
@@ -119,8 +119,9 @@ def _standardize(x):
119119
def _fixed_transformation(fs, x, f_index=0):
120120
"""
121121
Transform and standardize a vector by a transformation function.
122-
If the given index is within the function list f_index < len(fs), then use fs[f_index] as the transformation function
123-
otherwise, randomly choose a function from the function list.
122+
If the given index is within the function list f_index < len(fs), then use fs[f_index] as the transformation
123+
function. Otherwise, randomly choose a function from the function list.
124+
124125
Parameters
125126
----------
126127
fs : list
@@ -160,7 +161,8 @@ def _random_transformation(fs, x):
160161
# @staticmethod
161162
def _softmax(z, p, xb):
162163
"""
163-
Softmax function. This function is used to reversely solve the constant root value in the linear part to make the softmax function output mean to be a given value.
164+
Softmax function. This function is used to reversely solve the constant root value in the linear part to make the
165+
softmax function output mean to be a given value.
164166
165167
Parameters
166168
----------
@@ -201,7 +203,8 @@ def make_uplift_classification_logistic(
201203
n_samples : int, optional (default=1000)
202204
The number of samples to be generated for each treatment group.
203205
treatment_name: list, optional (default = ['control','treatment1','treatment2','treatment3'])
204-
The list of treatment names. The first element must be 'control' as control group, and the rest are treated as treatment groups.
206+
The list of treatment names. The first element must be 'control' as control group, and the rest are treated as
207+
treatment groups.
205208
y_name: string, optional (default = 'conversion')
206209
The name of the outcome variable to be used as a column in the output dataframe.
207210
n_classification_features: int, optional (default = 10)
@@ -218,7 +221,8 @@ def make_uplift_classification_logistic(
218221
n_mix_informative_uplift_dict: dictionary, optional (default: {'treatment1': 1, 'treatment2': 1, 'treatment3': 1})
219222
Number of mix features for each treatment. The mix feature is defined as a linear combination
220223
of a randomly selected informative classification feature and a randomly selected uplift feature.
221-
The mixture is made by a weighted sum (p*feature1 + (1-p)*feature2), where the weight p is drawn from a uniform distribution between 0 and 1.
224+
The mixture is made by a weighted sum (p*feature1 + (1-p)*feature2), where the weight p is drawn from a uniform
225+
distribution between 0 and 1.
222226
delta_uplift_dict: dictionary, optional (default: {'treatment1': .02, 'treatment2': .05, 'treatment3': -.05})
223227
Treatment effect (delta), can be positive or negative.
224228
Dictionary of {treatment_key: delta}.
@@ -227,14 +231,18 @@ def make_uplift_classification_logistic(
227231
random_seed : int, optional (default = 20200101)
228232
The random seed to be used in the data generation process.
229233
feature_association_list : list, optional (default = ['linear','quadratic','cubic','relu','sin','cos'])
230-
List of uplift feature association patterns to the treatment effect. For example, if the feature pattern is 'quadratic', then the treatment effect will increase or decrease quadratically with the feature.
231-
The values in the list must be one of ('linear','quadratic','cubic','relu','sin','cos'). However, the same value can appear multiple times in the list.
234+
List of uplift feature association patterns to the treatment effect. For example, if the feature pattern is
235+
'quadratic', then the treatment effect will increase or decrease quadratically with the feature.
236+
The values in the list must be one of ('linear','quadratic','cubic','relu','sin','cos'). However, the same
237+
value can appear multiple times in the list.
232238
random_select_association : boolean, optional (default = True)
233-
How the feature patterns are selected from the feature_association_list to be applied in the data generation process.
234-
If random_select_association = True, then for every uplift feature, a random feature association pattern is selected from the list.
235-
If random_select_association = False, then the feature association pattern is selected from the list in turns to be applied to each feature one by one.
239+
How the feature patterns are selected from the feature_association_list to be applied in the data generation
240+
process. If random_select_association = True, then for every uplift feature, a random feature association
241+
pattern is selected from the list. If random_select_association = False, then the feature association pattern
242+
is selected from the list in turns to be applied to each feature one by one.
236243
error_std : float, optional (default = 0.05)
237-
Standard deviation to be used in the error term of the logistic regression. The error is drawn from a normal distribution with mean 0 and standard deviation specified in this argument.
244+
Standard deviation to be used in the error term of the logistic regression. The error is drawn from a normal
245+
distribution with mean 0 and standard deviation specified in this argument.
238246
239247
Returns
240248
-------
@@ -273,7 +281,6 @@ def make_uplift_classification_logistic(
273281
f_list.append(feature_association_pattern_dict[fi])
274282

275283
# generate treatment key ------------------------------------------------#
276-
n_all = n * len(treatment_name)
277284
treatment_list = []
278285
for ti in treatment_name:
279286
treatment_list += [ti] * n
@@ -518,14 +525,16 @@ def make_uplift_classification(
518525
delta_uplift_decrease_dict: dictionary, optional (default: {'treatment1': 0., 'treatment2': 0., 'treatment3': 0.})
519526
Negative treatment effect created by the negative uplift features on the base classification label.
520527
Dictionary of {treatment_key: increase_delta}.
521-
n_uplift_increase_mix_informative_dict: dictionary, optional (default: {'treatment1': 1, 'treatment2': 1, 'treatment3': 1})
528+
n_uplift_increase_mix_informative_dict: dictionary, optional
522529
Number of positive mix features for each treatment. The positive mix feature is defined as a linear combination
523530
of a randomly selected informative classification feature and a randomly selected positive uplift feature.
524531
The linear combination is made by two coefficients sampled from a uniform distribution between -1 and 1.
525-
n_uplift_decrease_mix_informative_dict: dictionary, optional (default: {'treatment1': 0, 'treatment2': 0, 'treatment3': 0})
532+
default: {'treatment1': 1, 'treatment2': 1, 'treatment3': 1}
533+
n_uplift_decrease_mix_informative_dict: dictionary, optional
526534
Number of negative mix features for each treatment. The negative mix feature is defined as a linear combination
527535
of a randomly selected informative classification feature and a randomly selected negative uplift feature. The
528536
linear combination is made by two coefficients sampled from a uniform distribution between -1 and 1.
537+
default: {'treatment1': 0, 'treatment2': 0, 'treatment3': 0}
529538
positive_class_proportion: float, optional (default = 0.5)
530539
The proportion of positive label (1) in the control group.
531540
random_seed : int, optional (default = 20190101)

causalml/dataset/synthetic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from __future__ import absolute_import
2-
from __future__ import division
3-
from __future__ import print_function
41
from matplotlib import pyplot as plt
52
import numpy as np
63
import pandas as pd

0 commit comments

Comments
 (0)