
Commit 60f1286

ENH add evaluate_print() utility
1 parent 325f390

File tree: 2 files changed (+84, −0 lines)

imbalanced_ensemble/utils/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 from ._docstring import Substitution
 
+from ._evaluate import evaluate_print
+
 from ._validation import check_neighbors_object
 from ._validation import check_target_type
 from ._validation import check_sampling_strategy
@@ -15,6 +17,7 @@
 from ._validation_param import check_balancing_schedule
 
 __all__ = [
+    "evaluate_print",
     "check_neighbors_object",
     "check_sampling_strategy",
     "check_target_type",
imbalanced_ensemble/utils/_evaluate.py (new file)

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
from sklearn.metrics import balanced_accuracy_score, f1_score
from imbalanced_ensemble.metrics import geometric_mean_score

DEFAULT_METRICS = {
    'balanced Acc': (balanced_accuracy_score, {}),
    'macro Fscore': (f1_score, {'average': 'macro'}),
    'macro Gmean': (geometric_mean_score, {'average': 'macro'}),
}

def evaluate_print(y_true, y_pred, head: str = "",
                   eval_metrics: dict = DEFAULT_METRICS,
                   print_str: bool = True, return_str: bool = False):
    """Evaluate and print the predictive performance with respect to
    the given metrics.

    Returns a string of evaluation results.

    Parameters
    ----------
    y_true : 1d array-like, or label indicator array / sparse matrix
        Ground truth (correct) target values.

    y_pred : 1d array-like, or label indicator array / sparse matrix
        Estimated targets as returned by a classifier.

    head : string, default=""
        Head of the returned string, for example, the name of the predictor.

    eval_metrics : dict, default=DEFAULT_METRICS
        Metric(s) used for evaluation.

        - If not specified, 3 default metrics are used:

          - ``'balanced Acc'``:
            ``sklearn.metrics.balanced_accuracy_score()``
          - ``'macro Fscore'``:
            ``sklearn.metrics.f1_score(average='macro')``
          - ``'macro Gmean'``:
            ``imbens.metrics.geometric_mean_score(average='macro')``

        - If ``dict``, the keys should be strings corresponding to the
          evaluation metrics' names. The values should be tuples of the
          metric function (``callable``) and its additional kwargs (``dict``).

          - The metric function should take at least 2 named/keyword
            arguments, ``y_true`` and one of [``y_pred``, ``y_score``],
            and return a float as the evaluation score. Keyword arguments:

            - ``y_true``: 1d-array of shape (n_samples,), true labels or
              binary label indicators corresponding to the ground truth
              (correct) labels.
            - When using ``y_pred``, the input is a 1d-array of shape
              (n_samples,) of predicted labels, as returned by a classifier.
            - When using ``y_score``, the input is a 2d-array of shape
              (n_samples, n_classes) of probability estimates provided by
              the predict_proba method. In addition, the order of the class
              scores must correspond to the order of ``labels``, if provided
              in the metric function, or else to the numerical or
              lexicographical order of the labels in ``y_true``.

          - The additional kwargs should be a dictionary that specifies the
            extra arguments to pass into the metric function.

    print_str : bool, default=True
        Whether to print the results to stdout.

    return_str : bool, default=False
        Whether to return the result string.

    Returns
    -------
    result_str : string or NoneType
        The evaluation result string if ``return_str=True``, otherwise None.
    """
    result_str = head + " "
    for metric_name, (metric_func, kwargs) in eval_metrics.items():
        score = metric_func(y_true, y_pred, **kwargs)
        result_str += "{}: {:.3f} | ".format(metric_name, score)
    if print_str:
        print(result_str.rstrip(" |"))
    if return_str:
        return result_str.rstrip(" |")
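
A minimal usage sketch (not part of the commit): it assumes the new function is importable as imbalanced_ensemble.utils.evaluate_print, which follows from the __init__.py change above, and uses sklearn.metrics.precision_score purely as an illustrative custom metric.

# Usage sketch for the new utility (illustrative, not part of the commit).
import numpy as np
from sklearn.metrics import precision_score
from imbalanced_ensemble.utils import evaluate_print

y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2])

# Default metrics: balanced accuracy, macro F-score, macro G-mean.
evaluate_print(y_true, y_pred, head="DummyClf")
# Prints e.g.: DummyClf balanced Acc: 0.889 | macro Fscore: 0.867 | macro Gmean: ...

# Custom metrics: keys are display names, values are (callable, kwargs) tuples.
custom_metrics = {
    'macro Precision': (precision_score, {'average': 'macro'}),
}
result = evaluate_print(y_true, y_pred, head="DummyClf",
                        eval_metrics=custom_metrics,
                        print_str=False, return_str=True)
# result == "DummyClf macro Precision: 0.889"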
