Skip to content

Commit 6f4ec71

Browse files
authored
Fix #848: Pass estimation_sample_size parameter to individual trees in UpliftRandomForestClassifier (#850)
1 parent 75bd079 commit 6f4ec71

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

causalml/inference/tree/uplift.pyx

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def group_uniqueCounts_to_arr(np.ndarray[TR_TYPE_t, ndim=1] treatment_idx,
243243
tv = treatment_idx[i]
244244
# assume treatment index is in range
245245
out_arr[2*tv] += 1
246-
# assume y should be either 0 or 1, so this is summing
246+
# assume y should be either 0 or 1, so this is summing
247247
out_arr[2*tv + 1] += y[i]
248248
# adjust the entry at index 2*i to be N(Y = 0, T = i) = N(T = i) - N(Y = 1, T = i)
249249
for i in range(n_class):
@@ -322,7 +322,7 @@ def group_counts_by_divide(
322322
tv = treatment_idx[i]
323323
# assume treatment index is in range
324324
out_arr[2*tv] += 1
325-
# assume y should be either 0 or 1, so this is summing
325+
# assume y should be either 0 or 1, so this is summing
326326
out_arr[2*tv + 1] += y[i]
327327
# adjust the entry at index 2*i to be N(Y = 0, T = i) = N(T = i) - N(Y = 1, T = i)
328328
for i in range(n_class):
@@ -360,9 +360,9 @@ class UpliftTreeClassifier:
360360
n_reg: int, optional (default=100)
361361
The regularization parameter defined in Rzepakowski et al. 2012, the weight (in terms of sample size) of the
362362
parent node influence on the child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
363-
363+
364364
early_stopping_eval_diff_scale: float, optional (default=1)
365-
If train and valid uplift score diff bigger than
365+
If train and valid uplift score diff bigger than
366366
min(train_uplift_score,valid_uplift_score)/early_stopping_eval_diff_scale, stop.
367367
368368
control_name: string
@@ -404,7 +404,7 @@ class UpliftTreeClassifier:
404404
self.arr_eval_func = self.arr_evaluate_ED
405405
elif evaluationFunction == 'Chi':
406406
self.evaluationFunction = self.evaluate_Chi
407-
self.arr_eval_func = self.arr_evaluate_Chi
407+
self.arr_eval_func = self.arr_evaluate_Chi
408408
elif evaluationFunction == 'DDP':
409409
self.evaluationFunction = self.evaluate_DDP
410410
self.arr_eval_func = self.arr_evaluate_DDP
@@ -465,7 +465,7 @@ class UpliftTreeClassifier:
465465
y_val = (y_val > 0).astype(Y_TYPE) # make sure it is 0 or 1, and is int8
466466
treatment_val = np.asarray(treatment_val)
467467
assert len(y_val) == len(treatment_val), 'Data length must be equal for X_val, treatment_val, and y_val.'
468-
468+
469469
# Get treatment group keys. self.classes_[0] is reserved for the control group.
470470
treatment_groups = sorted([x for x in list(set(treatment)) if x != self.control_name])
471471
self.classes_ = [self.control_name]
@@ -1336,7 +1336,7 @@ class UpliftTreeClassifier:
13361336
np.ndarray[N_TYPE_t, ndim=1] right_node_summary_n):
13371337
'''
13381338
Calculate likelihood ratio test statistic as split evaluation criterion for a given node
1339-
1339+
13401340
NOTE: n_class should be 2.
13411341
13421342
Args
@@ -1365,7 +1365,7 @@ class UpliftTreeClassifier:
13651365
Has type numpy.int32.
13661366
The counts of each of the control
13671367
and treament groups of the right node, i.e. [N(T=i)...]
1368-
1368+
13691369
Returns
13701370
-------
13711371
lrt : Likelihood ratio test statistic
@@ -1422,7 +1422,7 @@ class UpliftTreeClassifier:
14221422
def evaluate_IDDP(nodeSummary):
14231423
'''
14241424
Calculate Delta P as split evaluation criterion for a given node.
1425-
1425+
14261426
Args
14271427
----
14281428
nodeSummary : dictionary
@@ -1444,7 +1444,7 @@ class UpliftTreeClassifier:
14441444
np.ndarray[N_TYPE_t, ndim=1] node_summary_n):
14451445
'''
14461446
Calculate Delta P as split evaluation criterion for a given node.
1447-
1447+
14481448
Args
14491449
----
14501450
node_summary_p : array of shape [n_class]
@@ -1589,7 +1589,7 @@ class UpliftTreeClassifier:
15891589
Normalization factor.
15901590
'''
15911591
cdef N_TYPE_t[::1] cur_summary_n = cur_node_summary_n
1592-
cdef N_TYPE_t[::1] left_summary_n = left_node_summary_n
1592+
cdef N_TYPE_t[::1] left_summary_n = left_node_summary_n
15931593
cdef int n_class = cur_summary_n.shape[0]
15941594
cdef int i = 0
15951595

@@ -1929,7 +1929,7 @@ class UpliftTreeClassifier:
19291929
cdef np.ndarray[N_TYPE_t, ndim=1] val_left_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
19301930
cdef np.ndarray[P_TYPE_t, ndim=1] val_right_summary_p = np.zeros(self.n_class, dtype = P_TYPE)
19311931
cdef np.ndarray[N_TYPE_t, ndim=1] val_right_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
1932-
1932+
19331933
# dummy
19341934
cdef int has_parent_summary = 0
19351935
if parentNodeSummary_p is None:
@@ -2107,7 +2107,7 @@ class UpliftTreeClassifier:
21072107
for k in range(n_class):
21082108
if (abs(val_left_summary_p[k] - left_summary_p[k]) >
21092109
min(val_left_summary_p[k], left_summary_p[k])/early_stopping_eval_diff_scale or
2110-
abs(val_right_summary_p[k] - right_summary_p[k]) >
2110+
abs(val_right_summary_p[k] - right_summary_p[k]) >
21112111
min(val_right_summary_p[k], right_summary_p[k])/early_stopping_eval_diff_scale):
21122112
early_stopping_flag = True
21132113
break
@@ -2160,13 +2160,13 @@ class UpliftTreeClassifier:
21602160
norm_factor = self.arr_normI(cur_summary_n, left_summary_n, alpha=0.9)
21612161
else:
21622162
norm_factor = 1
2163-
gain = gain / norm_factor
2163+
gain = gain / norm_factor
21642164
if (gain > bestGain and len_X_l > min_samples_leaf and len_X_r > min_samples_leaf):
21652165
bestGain = gain
21662166
bestGainImp = gain_for_imp
21672167
best_col = col
21682168
best_value = value
2169-
2169+
21702170
# after finding the best split col and value
21712171
if best_col is not None:
21722172
bestAttribute = (best_col, best_value)
@@ -2364,7 +2364,7 @@ class UpliftRandomForestClassifier:
23642364
child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
23652365
23662366
early_stopping_eval_diff_scale: float, optional (default=1)
2367-
If train and valid uplift score diff bigger than
2367+
If train and valid uplift score diff bigger than
23682368
min(train_uplift_score,valid_uplift_score)/early_stopping_eval_diff_scale, stop.
23692369
23702370
control_name: string
@@ -2427,6 +2427,7 @@ class UpliftRandomForestClassifier:
24272427
self.control_name = control_name
24282428
self.normalization = normalization
24292429
self.honesty = honesty
2430+
self.estimation_sample_size = estimation_sample_size
24302431
self.n_jobs = n_jobs
24312432
self.joblib_prefer = joblib_prefer
24322433

@@ -2477,6 +2478,7 @@ class UpliftRandomForestClassifier:
24772478
control_name=self.control_name,
24782479
normalization=self.normalization,
24792480
honesty=self.honesty,
2481+
estimation_sample_size=self.estimation_sample_size,
24802482
random_state=random_state.randint(MAX_INT))
24812483
for _ in range(self.n_estimators)
24822484
]
@@ -2512,7 +2514,7 @@ class UpliftRandomForestClassifier:
25122514
x_val_bt = X_val[bt_val_index]
25132515
y_val_bt = y_val[bt_val_index]
25142516
treatment_val_bt = treatment_val[bt_val_index]
2515-
2517+
25162518
tree.fit(X=x_train_bt, treatment=treatment_train_bt, y=y_train_bt, X_val=x_val_bt, treatment_val=treatment_val_bt, y_val=y_val_bt)
25172519
return tree
25182520

0 commit comments

Comments
 (0)