Skip to content

Commit 5b17c6c

Browse files
ENH boosting early_termination control (fix #7)
1 parent 1cc4201 commit 5b17c6c

File tree

10 files changed

+75
-14
lines changed

10 files changed

+75
-14
lines changed

imbalanced_ensemble/ensemble/_boost.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,12 @@ def __init__(self,
7676
sampling_type:str,
7777
learning_rate:float=1.,
7878
algorithm:str='SAMME.R',
79+
early_termination:bool=False,
7980
random_state=None):
8081

8182
self._sampling_type = sampling_type
8283
self.base_sampler = base_sampler
84+
self.early_termination = early_termination
8385

8486
super(ResampleBoostClassifier, self).__init__(
8587
base_estimator=base_estimator,
@@ -319,6 +321,9 @@ def _fit(self, X, y,
319321

320322
self.sampler_kwargs_ = check_type(
321323
sampler_kwargs, 'sampler_kwargs', dict)
324+
325+
early_termination_ = check_type(
326+
self.early_termination, 'early_termination', bool)
322327

323328
# Check that algorithm is supported.
324329
if self.algorithm not in ('SAMME', 'SAMME.R'):
@@ -435,14 +440,14 @@ def _fit(self, X, y,
435440
self._training_log_to_console(iboost, y_resampled)
436441

437442
# Early termination.
438-
if sample_weight is None:
443+
if sample_weight is None and early_termination_:
439444
print (f"Training early-stop at iteration"
440445
f" {iboost+1}/{self.n_estimators}"
441446
f" (sample_weight is None).")
442447
break
443448

444449
# Stop if error is zero.
445-
if estimator_error == 0:
450+
if estimator_error == 0 and early_termination_:
446451
print (f"Training early-stop at iteration"
447452
f" {iboost+1}/{self.n_estimators}"
448453
f" (training error is 0).")
@@ -451,7 +456,7 @@ def _fit(self, X, y,
451456
sample_weight_sum = np.sum(sample_weight)
452457

453458
# Stop if the sum of sample weights has become non-positive.
454-
if sample_weight_sum <= 0:
459+
if sample_weight_sum <= 0 and early_termination_:
455460
print (f"Training early-stop at iteration"
456461
f" {iboost+1}/{self.n_estimators}"
457462
f" (sample_weight_sum <= 0).")
@@ -521,8 +526,11 @@ def __init__(self,
521526
n_estimators:int,
522527
learning_rate:float=1.,
523528
algorithm:str='SAMME.R',
529+
early_termination:bool=False,
524530
random_state=None):
525531

532+
self.early_termination = early_termination
533+
526534
super(ReweightBoostClassifier, self).__init__(
527535
base_estimator=base_estimator,
528536
n_estimators=n_estimators,
@@ -737,6 +745,9 @@ def _fit(self, X, y,
737745
eval_metrics:dict,
738746
train_verbose:bool or int or dict,
739747
):
748+
749+
early_termination_ = check_type(
750+
self.early_termination, 'early_termination', bool)
740751

741752
# Check that algorithm is supported.
742753
if self.algorithm not in ('SAMME', 'SAMME.R'):
@@ -834,14 +845,14 @@ def _fit(self, X, y,
834845
self._training_log_to_console(iboost, y)
835846

836847
# Early termination.
837-
if sample_weight is None:
848+
if sample_weight is None and early_termination_:
838849
print (f"Training early-stop at iteration"
839850
f" {iboost+1}/{self.n_estimators}"
840851
f" (sample_weight is None).")
841852
break
842853

843854
# Stop if error is zero.
844-
if estimator_error == 0:
855+
if estimator_error == 0 and early_termination_:
845856
print (f"Training early-stop at iteration"
846857
f" {iboost+1}/{self.n_estimators}"
847858
f" (training error is 0).")
@@ -850,7 +861,7 @@ def _fit(self, X, y,
850861
sample_weight_sum = np.sum(sample_weight)
851862

852863
# Stop if the sum of sample weights has become non-positive.
853-
if sample_weight_sum <= 0:
864+
if sample_weight_sum <= 0 and early_termination_:
854865
print (f"Training early-stop at iteration"
855866
f" {iboost+1}/{self.n_estimators}"
856867
f" (sample_weight_sum <= 0).")

imbalanced_ensemble/ensemble/compatible/adaboost_compatible.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from ..base import ImbalancedEnsembleClassifierMixin, MAX_INT
1818
from ...utils._validation_data import check_eval_datasets
1919
from ...utils._validation_param import (check_train_verbose,
20-
check_eval_metrics)
20+
check_eval_metrics,
21+
check_type)
2122
from ...utils._validation import _deprecate_positional_args
2223
from ...utils._docstring import (Substitution, FuncSubstitution,
2324
FuncGlossarySubstitution,
@@ -49,6 +50,7 @@
4950

5051

5152
@Substitution(
53+
early_termination=_get_parameter_docstring('early_termination', **_properties),
5254
example=_get_example_docstring(_method_name)
5355
)
5456
class CompatibleAdaBoostClassifier(ImbalancedEnsembleClassifierMixin,
@@ -70,7 +72,6 @@ class CompatibleAdaBoostClassifier(ImbalancedEnsembleClassifierMixin,
7072
Support for sample weighting is required, as well as proper
7173
``classes_`` and ``n_classes_`` attributes. If ``None``, then
7274
the base estimator is :class:`~sklearn.tree.DecisionTreeClassifier`
73-
initialized with `max_depth=1`.
7475
7576
n_estimators : int, default=50
7677
The maximum number of estimators at which boosting is terminated.
@@ -87,6 +88,8 @@ class CompatibleAdaBoostClassifier(ImbalancedEnsembleClassifierMixin,
8788
If 'SAMME' then use the SAMME discrete boosting algorithm.
8889
The SAMME.R algorithm typically converges faster than SAMME,
8990
achieving a lower test error with fewer boosting iterations.
91+
92+
{early_termination}
9093
9194
random_state : int, RandomState instance or None, default=None
9295
Controls the random seed given at each `base_estimator` at each
@@ -148,8 +151,11 @@ def __init__(self,
148151
n_estimators:int=50,
149152
learning_rate:float=1.,
150153
algorithm:str='SAMME.R',
154+
early_termination:bool=False,
151155
random_state=None):
152156

157+
self.early_termination = early_termination
158+
153159
super(CompatibleAdaBoostClassifier, self).__init__(
154160
base_estimator=base_estimator,
155161
n_estimators=n_estimators,
@@ -200,6 +206,9 @@ def fit(self, X, y,
200206
self : object
201207
"""
202208

209+
early_termination_ = check_type(
210+
self.early_termination, 'early_termination', bool)
211+
203212
# Check that algorithm is supported.
204213
if self.algorithm not in ('SAMME', 'SAMME.R'):
205214
raise ValueError("algorithm %s is not supported" % self.algorithm)
@@ -281,14 +290,14 @@ def fit(self, X, y,
281290
self._training_log_to_console(iboost, y)
282291

283292
# Early termination.
284-
if sample_weight is None:
293+
if sample_weight is None and early_termination_:
285294
print (f"Training early-stop at iteration"
286295
f" {iboost+1}/{self.n_estimators}"
287296
f" (sample_weight is None).")
288297
break
289298

290299
# Stop if error is zero.
291-
if estimator_error == 0:
300+
if estimator_error == 0 and early_termination_:
292301
print (f"Training early-stop at iteration"
293302
f" {iboost+1}/{self.n_estimators}"
294303
f" (training error is 0).")
@@ -297,7 +306,7 @@ def fit(self, X, y,
297306
sample_weight_sum = np.sum(sample_weight)
298307

299308
# Stop if the sum of sample weights has become non-positive.
300-
if sample_weight_sum <= 0:
309+
if sample_weight_sum <= 0 and early_termination_:
301310
print (f"Training early-stop at iteration"
302311
f" {iboost+1}/{self.n_estimators}"
303312
f" (sample_weight_sum <= 0).")

imbalanced_ensemble/ensemble/over_sampling/kmeans_smote_boost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
@Substitution(
5252
n_jobs_sampler=_get_parameter_docstring('n_jobs_sampler', **_properties),
53+
early_termination=_get_parameter_docstring('early_termination', **_properties),
5354
random_state=_get_parameter_docstring('random_state', **_properties),
5455
example=_get_example_docstring(_method_name)
5556
)
@@ -108,6 +109,8 @@ class KmeansSMOTEBoostClassifier(ResampleBoostClassifier):
108109
If 'SAMME' then use the SAMME discrete boosting algorithm.
109110
The SAMME.R algorithm typically converges faster than SAMME,
110111
achieving a lower test error with fewer boosting iterations.
112+
113+
{early_termination}
111114
112115
{random_state}
113116
@@ -177,6 +180,7 @@ def __init__(self,
177180
density_exponent="auto",
178181
learning_rate:float=1.,
179182
algorithm:str='SAMME.R',
183+
early_termination:bool=False,
180184
random_state=None):
181185

182186
base_sampler = _sampler_class()
@@ -189,6 +193,7 @@ def __init__(self,
189193
sampling_type=sampling_type,
190194
learning_rate=learning_rate,
191195
algorithm=algorithm,
196+
early_termination=early_termination,
192197
random_state=random_state)
193198

194199
self.__name__ = _method_name

imbalanced_ensemble/ensemble/over_sampling/over_boost.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545

4646
@Substitution(
47-
n_jobs_sampler=_get_parameter_docstring('n_jobs_sampler', **_properties),
47+
early_termination=_get_parameter_docstring('early_termination', **_properties),
4848
random_state=_get_parameter_docstring('random_state', **_properties),
4949
example=_get_example_docstring(_method_name)
5050
)
@@ -81,6 +81,8 @@ class OverBoostClassifier(ResampleBoostClassifier):
8181
If 'SAMME' then use the SAMME discrete boosting algorithm.
8282
The SAMME.R algorithm typically converges faster than SAMME,
8383
achieving a lower test error with fewer boosting iterations.
84+
85+
{early_termination}
8486
8587
{random_state}
8688
@@ -145,6 +147,7 @@ def __init__(self,
145147
*,
146148
learning_rate:float=1.,
147149
algorithm:str='SAMME.R',
150+
early_termination:bool=False,
148151
random_state=None):
149152

150153
base_sampler = _sampler_class()
@@ -157,6 +160,7 @@ def __init__(self,
157160
sampling_type=sampling_type,
158161
learning_rate=learning_rate,
159162
algorithm=algorithm,
163+
early_termination=early_termination,
160164
random_state=random_state)
161165

162166
self.__name__ = _method_name

imbalanced_ensemble/ensemble/over_sampling/smote_boost.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050

5151
@Substitution(
52-
n_jobs_sampler=_get_parameter_docstring('n_jobs_sampler', **_properties),
52+
early_termination=_get_parameter_docstring('early_termination', **_properties),
5353
random_state=_get_parameter_docstring('random_state', **_properties),
5454
example=_get_example_docstring(_method_name)
5555
)
@@ -91,6 +91,8 @@ class balancing by SMOTE over-sampling the sample at each
9191
If 'SAMME' then use the SAMME discrete boosting algorithm.
9292
The SAMME.R algorithm typically converges faster than SAMME,
9393
achieving a lower test error with fewer boosting iterations.
94+
95+
{early_termination}
9496
9597
{random_state}
9698
@@ -156,6 +158,7 @@ def __init__(self,
156158
k_neighbors:int=5,
157159
learning_rate:float=1.,
158160
algorithm:str='SAMME.R',
161+
early_termination:bool=False,
159162
random_state=None):
160163

161164
base_sampler = _sampler_class()
@@ -168,6 +171,7 @@ def __init__(self,
168171
sampling_type=sampling_type,
169172
learning_rate=learning_rate,
170173
algorithm=algorithm,
174+
early_termination=early_termination,
171175
random_state=random_state)
172176

173177
self.__name__ = _method_name

imbalanced_ensemble/ensemble/reweighting/adacost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343

4444
@Substitution(
45+
early_termination=_get_parameter_docstring('early_termination', **_properties),
4546
random_state=_get_parameter_docstring('random_state', **_properties),
4647
example=_get_example_docstring(_method_name)
4748
)
@@ -82,6 +83,8 @@ class AdaCostClassifier(ReweightBoostClassifier):
8283
The SAMME.R algorithm typically converges faster than SAMME,
8384
achieving a lower test error with fewer boosting iterations.
8485
86+
{early_termination}
87+
8588
{random_state}
8689
8790
Attributes
@@ -143,13 +146,15 @@ def __init__(self,
143146
*,
144147
learning_rate:float=1.,
145148
algorithm:str='SAMME.R',
149+
early_termination:bool=False,
146150
random_state=None):
147151

148152
super(AdaCostClassifier, self).__init__(
149153
base_estimator=base_estimator,
150154
n_estimators=n_estimators,
151155
learning_rate=learning_rate,
152156
algorithm=algorithm,
157+
early_termination=early_termination,
153158
random_state=random_state)
154159

155160
self.__name__ = _method_name

imbalanced_ensemble/ensemble/reweighting/adauboost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747

4848
@Substitution(
49+
early_termination=_get_parameter_docstring('early_termination', **_properties),
4950
random_state=_get_parameter_docstring('random_state', **_properties),
5051
example=_get_example_docstring(_method_name)
5152
)
@@ -83,6 +84,8 @@ class AdaUBoostClassifier(ReweightBoostClassifier):
8384
The SAMME.R algorithm typically converges faster than SAMME,
8485
achieving a lower test error with fewer boosting iterations.
8586
87+
{early_termination}
88+
8689
{random_state}
8790
8891
Attributes
@@ -147,6 +150,7 @@ def __init__(self,
147150
*,
148151
learning_rate:float=1.,
149152
algorithm:str='SAMME.R',
153+
early_termination:bool=False,
150154
random_state=None):
151155

152156
self.__name__ = 'AdaUBoostClassifier'
@@ -156,6 +160,7 @@ def __init__(self,
156160
n_estimators=n_estimators,
157161
learning_rate=learning_rate,
158162
algorithm=algorithm,
163+
early_termination=early_termination,
159164
random_state=random_state)
160165

161166

imbalanced_ensemble/ensemble/reweighting/asymmetric_boost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343

4444
@Substitution(
45+
early_termination=_get_parameter_docstring('early_termination', **_properties),
4546
random_state=_get_parameter_docstring('random_state', **_properties),
4647
example=_get_example_docstring(_method_name)
4748
)
@@ -78,6 +79,8 @@ class AsymBoostClassifier(ReweightBoostClassifier):
7879
The SAMME.R algorithm typically converges faster than SAMME,
7980
achieving a lower test error with fewer boosting iterations.
8081
82+
{early_termination}
83+
8184
{random_state}
8285
8386
Attributes
@@ -137,13 +140,15 @@ def __init__(self,
137140
*,
138141
learning_rate:float=1.,
139142
algorithm:str='SAMME.R',
143+
early_termination:bool=False,
140144
random_state=None):
141145

142146
super(AsymBoostClassifier, self).__init__(
143147
base_estimator=base_estimator,
144148
n_estimators=n_estimators,
145149
learning_rate=learning_rate,
146150
algorithm=algorithm,
151+
early_termination=early_termination,
147152
random_state=random_state)
148153

149154
self.__name__ = _method_name

0 commit comments

Comments (0)