Skip to content

Commit ed25700

Browse files
authored
fix(sklearn): probas_pred got removed since 1.7 (#879)
1 parent f11c5fb commit ed25700

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies = [
4444

4545
[project.optional-dependencies]
4646
image = ["numpy", "pillow"]
47-
sklearn = ["scikit-learn"]
47+
sklearn = ["scikit-learn>=1.5.0"]
4848
plots = ["scikit-learn", "pandas", "numpy"]
4949
markdown = ["matplotlib"]
5050
tests = [

src/dvclive/plots/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def dump(self, val, **kwargs) -> None:
5454
from sklearn import metrics
5555

5656
precision, recall, prc_thresholds = metrics.precision_recall_curve(
57-
y_true=val[0], probas_pred=val[1], **kwargs
57+
y_true=val[0], y_score=val[1], **kwargs
5858
)
5959

6060
prc = {

tests/frameworks/test_huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
6767

6868

6969
class RegressionPreTrainedModel(PreTrainedModel):
70-
config_class = RegressionModelConfig
70+
config_class = RegressionModelConfig # type: ignore[assignment]
7171
base_model_prefix = "regression"
7272

7373
def __init__(self, config):

tests/plots/test_sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_log_prc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
7878

7979
live.log_sklearn_plot("precision_recall", y_true, y_score)
8080

81-
spy.assert_called_once_with(y_true=y_true, probas_pred=y_score)
81+
spy.assert_called_once_with(y_true=y_true, y_score=y_score)
8282
assert (out / "precision_recall.json").exists()
8383

8484

0 commit comments

Comments
 (0)