Skip to content

Commit bf8305b

Browse files
lona-kbe-marc
andauthored
feat: use new condition system for errors and warnings (#1398)
* fix: change classif to regr * add error classes with error_*() * change warnings to warning_* * fix tests to work with <*> * Apply suggestions from code review * Update R/as_data_backend.R * Update R/Learner.R * Update R/LearnerRegr.R * Update R/LearnerRegr.R * ... * ... * ... * ... * ... --------- Co-authored-by: Marc Becker <[email protected]> Co-authored-by: be-marc <[email protected]>
1 parent 05792a6 commit bf8305b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+258
-213
lines changed

R/BenchmarkResult.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
124124
if (!is.null(bmr)) {
125125
assert_benchmark_result(bmr)
126126
if (private$.data$iterations() && self$task_type != bmr$task_type) {
127-
stopf("BenchmarkResult is of task type '%s', but must be '%s'", bmr$task_type, self$task_type)
127+
error_input("BenchmarkResult is of task type '%s', but must be '%s'", bmr$task_type, self$task_type)
128128
}
129129

130130
private$.data$combine(get_private(bmr)$.data)
@@ -425,8 +425,8 @@ BenchmarkResult = R6Class("BenchmarkResult",
425425
resample_result = function(i = NULL, uhash = NULL, task_id = NULL, learner_id = NULL,
426426
resampling_id = NULL) {
427427
uhash = private$.get_uhashes(i, uhash, learner_id, task_id, resampling_id)
428-
if (length(uhash) != 1L) {
429-
stopf("Method requires selecting exactly one ResampleResult, but got %s",
428+
if (length(uhash) != 1) {
429+
error_input("Method requires selecting exactly one ResampleResult, but got %s",
430430
length(uhash))
431431
}
432432
ResampleResult$new(private$.data, view = uhash)
@@ -598,7 +598,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
598598
resampling_ids = resampling_ids), is.null)
599599

600600
if (sum(!is.null(i), !is.null(uhashes), length(args) > 0L) > 1) {
601-
stopf("At most one of `i`, `uhash`, or IDs can be provided.")
601+
error_input("At most one of `i`, `uhash`, or IDs can be provided.")
602602
}
603603
if (!is.null(i)) {
604604
uhashes = self$uhashes
@@ -609,7 +609,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
609609
uhashes = invoke(match.fun("uhashes"), bmr = self, .args = args)
610610
}
611611
if (length(uhashes) == 0L) {
612-
stopf("No resample results found for the given arguments.")
612+
error_input("No resample results found for the given arguments.")
613613
}
614614
uhashes
615615
},
@@ -714,7 +714,7 @@ uhash = function(bmr, learner_id = NULL, task_id = NULL, resampling_id = NULL) {
714714
assert_string(resampling_id, null.ok = TRUE)
715715
uhash = uhashes(bmr, learner_id, task_id, resampling_id)
716716
if (length(uhash) != 1) {
717-
stopf("Expected exactly one uhash, got %s", length(uhash))
717+
error_input("Expected exactly one uhash, got %s", length(uhash))
718718
}
719719
uhash
720720
}

R/DataBackendCbind.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ DataBackendCbind = R6Class("DataBackendCbind", inherit = DataBackend, cloneable
77
pk = b1$primary_key
88

99
if (pk != b2$primary_key) {
10-
stopf("All backends to cbind must have the primary_key '%s'", pk)
10+
error_input("All backends to cbind must have the primary_key '%s'", pk)
1111
}
1212

1313
super$initialize(list(b1 = b1, b2 = b2), pk)

R/DataBackendDataTable.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
4747
super$initialize(setkeyv(data, primary_key), primary_key)
4848
ii = match(primary_key, names(data))
4949
if (is.na(ii)) {
50-
stopf("Primary key '%s' not in 'data'", primary_key)
50+
error_input("Primary key '%s' not in 'data'", primary_key)
5151
}
5252
private$.cache = set_names(replace(rep(NA, ncol(data)), ii, FALSE), names(data))
5353
},

R/DataBackendRbind.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ DataBackendRbind = R6Class("DataBackendRbind", inherit = DataBackend, cloneable
77
pk = b1$primary_key
88

99
if (pk != b2$primary_key) {
10-
stopf("All backends to rbind must have the primary_key '%s'", pk)
10+
error_input("All backends to rbind must have the primary_key '%s'", pk)
1111
}
1212

1313
super$initialize(list(b1 = b1, b2 = b2), pk)

R/DataBackendRename.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
1616
new = new[ii]
1717

1818
if (self$primary_key %chin% old) {
19-
stopf("Renaming the primary key is not supported")
19+
error_input("Renaming the primary key is not supported")
2020
}
2121

2222

2323
resulting_names = map_values(b$colnames, old, new)
2424
dup = anyDuplicated(resulting_names)
2525
if (dup > 0L) {
26-
stopf("Duplicated column name after rename: %s", resulting_names[dup])
26+
error_input("Duplicated column name after rename: %s", resulting_names[dup])
2727
}
2828

2929
self$old = old

R/HotstartStack.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ HotstartStack = R6Class("HotstartStack",
8787

8888
walk(learners, function(learner) {
8989
if (!is.null(get0("validate", learner))) {
90-
stopf("Hotstart learners that did validation is currently not supported.")
90+
error_input("Hotstart learners that did validation is currently not supported.")
9191
} else if (is.null(learner$model)) {
92-
stopf("Learners must be trained before adding them to the hotstart stack.")
92+
error_input("Learners must be trained before adding them to the hotstart stack.")
9393
} else if (is_marshaled_model(learner$model)) {
94-
stopf("Learners must be unmarshaled before adding them to the hotstart stack.")
94+
error_input("Learners must be unmarshaled before adding them to the hotstart stack.")
9595
}
9696
})
9797

R/Learner.R

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -375,14 +375,14 @@ Learner = R6Class("Learner",
375375
predict = function(task, row_ids = NULL) {
376376
# improve error message for the common mistake of passing a data.frame here
377377
if (is.data.frame(task)) {
378-
stopf("To predict on data.frames, use the method `$predict_newdata()` instead of `$predict()`")
378+
error_input("To predict on data.frames, use the method `$predict_newdata()` instead of `$predict()`")
379379
}
380380
task = assert_task(as_task(task))
381381
assert_predictable(task, self)
382382
row_ids = assert_row_ids(row_ids, task = task, null.ok = TRUE)
383383

384384
if (is.null(self$state$model) && is.null(self$state$fallback_state$model)) {
385-
stopf("Cannot predict, Learner '%s' has not been trained yet", self$id)
385+
error_input("Cannot predict, Learner '%s' has not been trained yet", self$id)
386386
}
387387

388388
# we need to marshal for call-r prediction and parallel prediction, but afterwards we reset the model
@@ -452,7 +452,7 @@ Learner = R6Class("Learner",
452452
predict_newdata = function(newdata, task = NULL) {
453453
if (is.null(task)) {
454454
if (is.null(self$state$train_task)) {
455-
stopf("No task stored, and no task provided")
455+
error_input("No task stored, and no task provided")
456456
}
457457
task = self$state$train_task$clone()
458458
} else {
@@ -618,7 +618,7 @@ Learner = R6Class("Learner",
618618
fallback$id, self$id, str_collapse(missing_properties), class = "Mlr3WarningConfigFallbackProperties")
619619
}
620620
} else if (method == "none" && !is.null(fallback)) {
621-
stopf("Fallback learner must be `NULL` if encapsulation is set to `none`.")
621+
error_input("Fallback learner must be `NULL` if encapsulation is set to `none`.")
622622
}
623623

624624
private$.encapsulation = c(train = method, predict = method)
@@ -665,7 +665,7 @@ Learner = R6Class("Learner",
665665
for (i in seq_along(new_values)) {
666666
nn = ndots[[i]]
667667
if (!exists(nn, envir = self, inherits = FALSE)) {
668-
stopf("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
668+
error_config("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
669669
nn, class(self)[1L], did_you_mean(nn, c(param_ids, setdiff(names(self), ".__enclos_env__")))) # nolint
670670
}
671671
self[[nn]] = new_values[[i]]
@@ -681,10 +681,10 @@ Learner = R6Class("Learner",
681681
#' If set to `"error"`, an error is thrown, otherwise all features are returned.
682682
selected_features = function() {
683683
if (is.null(self$model)) {
684-
stopf("No model stored")
684+
error_input("No model stored")
685685
}
686686
if (private$.selected_features_impute == "error") {
687-
stopf("Learner does not support feature selection")
687+
error_input("Learner does not support feature selection")
688688
} else {
689689
self$state$feature_names
690690
}
@@ -790,15 +790,15 @@ Learner = R6Class("Learner",
790790

791791
assert_string(rhs, .var.name = "predict_type")
792792
if (rhs %nin% self$predict_types) {
793-
stopf("Learner '%s' does not support predict type '%s'", self$id, rhs)
793+
error_input("Learner '%s' does not support predict type '%s'", self$id, rhs) # TODO error_learner?
794794
}
795795
private$.predict_type = rhs
796796
},
797797

798798
#' @template field_param_set
799799
param_set = function(rhs) {
800800
if (!missing(rhs) && !identical(rhs, private$.param_set)) {
801-
stopf("param_set is read-only.")
801+
error_input("param_set is read-only.")
802802
}
803803
private$.param_set
804804
},
@@ -866,7 +866,7 @@ Learner = R6Class("Learner",
866866
# return: Numeric vector of weights or `no_weights_val` (default NULL)
867867
.get_weights = function(task, no_weights_val = NULL) {
868868
if ("weights" %nin% self$properties) {
869-
stop("private$.get_weights should not be used in Learners that do not have the 'weights' property.")
869+
error_mlr3("private$.get_weights should not be used in Learners that do not have the 'weights' property.")
870870
}
871871
if (self$use_weights == "use" && "weights_learner" %in% task$properties) {
872872
task$weights_learner$weight
@@ -916,7 +916,7 @@ default_values.Learner = function(x, search_space, task, ...) { # nolint
916916
values = default_values(x$param_set)
917917

918918
if (any(search_space$ids() %nin% names(values))) {
919-
stopf("Could not find default values for the following parameters: %s",
919+
error_input("Could not find default values for the following parameters: %s",
920920
str_collapse(setdiff(search_space$ids(), names(values))))
921921
}
922922

R/LearnerClassif.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ LearnerClassif = R6Class("LearnerClassif", inherit = Learner,
7777
#'
7878
#' @return `list()` with elements `"response"` or `"prob"` depending on the predict type.
7979
predict_newdata_fast = function(newdata, task = NULL) {
80-
if (is.null(task) && is.null(self$state$train_task)) stopf("No task stored, and no task provided")
80+
if (is.null(task) && is.null(self$state$train_task)) error_input("No task stored, and no task provided")
8181
feature_names = self$state$train_task$feature_names %??% task$feature_names
8282
class_names = self$state$train_task$class_names %??% task$class_names
8383

R/LearnerClassifDebug.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
113113
#' @return Named `numeric()`.
114114
importance = function() {
115115
if (is.null(self$model)) {
116-
stopf("No model stored")
116+
error_input("No model stored")
117117
}
118118
fns = self$state$feature_names
119119
set_names(rep(0, length(fns)), fns)
@@ -124,7 +124,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
124124
#' @return `character()`.
125125
selected_features = function() {
126126
if (is.null(self$model)) {
127-
stopf("No model stored")
127+
error_input("No model stored")
128128
}
129129
character(0)
130130
}
@@ -180,10 +180,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
180180
message("Message from classif.debug->train()")
181181
}
182182
if (roll("warning_train")) {
183-
warningf("Warning from classif.debug->train()")
183+
warning_mlr3("Warning from classif.debug->train()")
184184
}
185185
if (roll("error_train")) {
186-
stopf("Error from classif.debug->train()")
186+
error_learner_train("Error from classif.debug->train()")
187187
}
188188
if (roll("segfault_train")) {
189189
get("attach")(structure(list(), class = "UserDefinedDatabase"))
@@ -192,7 +192,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
192192
valid_truth = if (!is.null(task$internal_valid_task)) task$internal_valid_task$truth()
193193

194194
if (isTRUE(pv$early_stopping) && is.null(valid_truth)) {
195-
stopf("Early stopping is only possible when a validation task is present.")
195+
error_config("Early stopping is only possible when a validation task is present.")
196196
}
197197

198198
model = list(
@@ -248,7 +248,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
248248
},
249249
.predict = function(task) {
250250
if (!is.null(self$model$marshal_pid) && self$model$marshal_pid != Sys.getpid()) {
251-
stopf("Model was not unmarshaled correctly")
251+
error_mlr3("Model was not unmarshaled correctly")
252252
}
253253
n = task$nrow
254254
pv = self$param_set$get_values(tags = "predict")
@@ -265,10 +265,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
265265
message("Message from classif.debug->predict()")
266266
}
267267
if (roll("warning_predict")) {
268-
warningf("Warning from classif.debug->predict()")
268+
warning_mlr3("Warning from classif.debug->predict()")
269269
}
270270
if (roll("error_predict")) {
271-
stopf("Error from classif.debug->predict()")
271+
error_learner_predict("Error from classif.debug->predict()")
272272
}
273273
if (roll("segfault_predict")) {
274274
get("attach")(structure(list(), class = "UserDefinedDatabase"))

R/LearnerClassifFeatureless.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
5454
#' @return Named `numeric()`.
5555
importance = function() {
5656
if (is.null(self$model)) {
57-
stopf("No model stored")
57+
error_learner("No model stored")
5858
}
5959
fn = self$model$features
6060
named_vector(fn, 0)
@@ -65,7 +65,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
6565
#' @return `character(0)`.
6666
selected_features = function() {
6767
if (is.null(self$model)) {
68-
stopf("No model stored")
68+
error_learner("No model stored")
6969
}
7070
character()
7171
}

0 commit comments

Comments
 (0)