Skip to content

Commit af8ec5c

Browse files
be-marcCopilot
andauthored
refactor: obs_loss (#1411)
* ... * ... * Update R/ResampleResult.R Co-authored-by: Copilot <[email protected]> * ... * ... * Update R/Prediction.R Co-authored-by: Copilot <[email protected]> * Update R/ResampleResult.R Co-authored-by: Copilot <[email protected]> * Update R/BenchmarkResult.R Co-authored-by: Copilot <[email protected]> * Update tests/testthat/test_PredictionClassif.R Co-authored-by: Copilot <[email protected]> * Update R/Prediction.R Co-authored-by: Copilot <[email protected]> * Update R/MeasureSimple.R Co-authored-by: Copilot <[email protected]> * ... * ... * ... * ... * ... --------- Co-authored-by: Copilot <[email protected]>
1 parent 3c6d7dc commit af8ec5c

31 files changed

+344
-145
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ Suggests:
7676
RhpcBLASctl,
7777
rpart,
7878
testthat (>= 3.3.0)
79+
Remotes:
80+
mlr-org/mlr3measures
7981
Encoding: UTF-8
8082
Config/testthat/edition: 3
8183
Config/testthat/parallel: false

R/BenchmarkResult.R

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
224224
},
225225

226226
#' @description
227-
#' Calculates the observation-wise loss via the loss function set in the
228-
#' [Measure]'s field `obs_loss`.
229-
#' Returns a `data.table()` with the columns `row_ids`, `truth`, `response` and
230-
#' one additional numeric column for each measure, named with the respective measure id.
231-
#' If there is no observation-wise loss function for the measure, the column is filled with
232-
#' `NA` values.
233-
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an
234-
#' additional transformation after aggregation, in this example taking the square-root.
227+
#' Calculates the observation-wise loss via the [Measure]'s `obs_loss` method.
228+
#' Returns a `data.table()` with columns from the predictions (e.g., `row_ids`, `truth`, `response`, etc.), plus one numeric column for each measure, named with the respective measure id, and a `resample_result` column.
229+
#' If there is no observation-wise loss function for the measure, the column is filled with `NA_real_` values.
230+
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an additional transformation after aggregation, in this example taking the square-root.
231+
#'
235232
#' @param predict_sets (`character()`)\cr
236233
#' The predict sets.
237234
#' @examples

R/Measure.R

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,6 @@ Measure = R6Class("Measure",
9191
#' @template field_param_set
9292
param_set = NULL,
9393

94-
#' @field obs_loss (`function()` | `NULL`)
95-
#' Function to calculate the observation-wise loss.
96-
obs_loss = NULL,
97-
9894
#' @field trafo (`list()` | `NULL`)
9995
#' `NULL` or a list with two elements:
10096
#' * `trafo`: the transformation function applied after aggregating
@@ -138,11 +134,23 @@ Measure = R6Class("Measure",
138134
#' Creates a new instance of this [R6][R6::R6Class] class.
139135
#'
140136
#' Note that this object is typically constructed via a derived classes, e.g. [MeasureClassif] or [MeasureRegr].
141-
initialize = function(id, task_type = NA, param_set = ps(), range = c(-Inf, Inf), minimize = NA, average = "macro",
142-
aggregator = NULL, obs_loss = NULL, properties = character(), predict_type = "response",
143-
predict_sets = "test", task_properties = character(), packages = character(),
144-
label = NA_character_, man = NA_character_, trafo = NULL) {
145-
137+
initialize = function(
138+
id,
139+
task_type = NA,
140+
param_set = ps(),
141+
range = c(-Inf, Inf),
142+
minimize = NA,
143+
average = "macro",
144+
aggregator = NULL,
145+
properties = character(),
146+
predict_type = "response",
147+
predict_sets = "test",
148+
task_properties = character(),
149+
packages = character(),
150+
label = NA_character_,
151+
man = NA_character_,
152+
trafo = NULL
153+
) {
146154
self$id = assert_string(id, min.chars = 1L)
147155
self$label = assert_string(label, na.ok = TRUE)
148156
self$task_type = task_type
@@ -151,7 +159,6 @@ Measure = R6Class("Measure",
151159
self$minimize = assert_flag(minimize, na.ok = TRUE)
152160
self$average = average
153161
private$.aggregator = assert_function(aggregator, null.ok = TRUE)
154-
self$obs_loss = assert_function(obs_loss, null.ok = TRUE)
155162
self$trafo = assert_list(trafo, len = 2L, types = "function", null.ok = TRUE)
156163
if (!is.null(self$trafo)) {
157164
assert_permutation(names(trafo), c("fn", "deriv"))
@@ -311,6 +318,35 @@ Measure = R6Class("Measure",
311318
private$.aggregator(rr)
312319
}
313320
)
321+
},
322+
323+
#' @description
324+
#' Calculates the observation-wise loss.
325+
#' Returns a `numeric()` with one element for each row in the [Prediction].
326+
#' If there is no observation-wise loss function for the measure, `NA_real_` values are returned.
327+
#'
328+
#' @param prediction ([Prediction]).
329+
#' @param task ([Task]).
330+
#' @param learner ([Learner]).
331+
#'
332+
#' @return `numeric()` with one element for each row in the [Prediction].
333+
#' @examples
334+
#' task = tsk("penguins")
335+
#' learner = lrn("classif.rpart")
336+
#' learner$train(task)
337+
#' prediction = learner$predict(task)
338+
#' msr("classif.ce")$obs_loss(prediction)
339+
obs_loss = function(prediction, task = NULL, learner = NULL) {
340+
341+
if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
342+
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
343+
}
344+
345+
if ("obs_loss" %nin% self$properties) {
346+
return(rep(NA_real_, length(prediction$row_ids)))
347+
}
348+
349+
private$.obs_loss(prediction, task)
314350
}
315351
),
316352

@@ -330,7 +366,7 @@ Measure = R6Class("Measure",
330366
hash = function(rhs) {
331367
assert_ro_binding(rhs)
332368
calculate_hash(class(self), self$id, self$param_set$values, private$.score,
333-
private$.average, private$.aggregator, self$obs_loss, self$trafo,
369+
private$.average, private$.aggregator, private$.obs_loss, self$trafo,
334370
self$predict_sets, mget(private$.extra_hash, envir = self), private$.use_weights)
335371
},
336372

@@ -412,6 +448,9 @@ Measure = R6Class("Measure",
412448
.use_weights = NULL,
413449
.score = function(prediction, task, weights, ...) {
414450
stop("abstract method")
451+
},
452+
.obs_loss = function(prediction, task, ...) {
453+
stop("abstract method")
415454
}
416455
)
417456
)

R/MeasureSimple.R

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,24 @@ MeasureBinarySimple = R6Class("MeasureBinarySimple",
33
inherit = MeasureClassif,
44
public = list(
55
fun = NULL,
6+
fun_obs_loss = NULL,
67
na_value = NaN,
78
initialize = function(name, param_set = ps()) {
89
info = mlr3measures::measures[[name]]
910
weights = info$sample_weights
11+
properties = if (weights) "weights" else character()
12+
13+
if (!is.na(info$obs_loss)) {
14+
properties = c(properties, "obs_loss")
15+
self$fun_obs_loss = get(info$obs_loss, envir = asNamespace("mlr3measures"), mode = "function")
16+
}
1017

1118
super$initialize(
1219
id = paste0("classif.", name),
1320
param_set = param_set,
1421
range = c(info$lower, info$upper),
1522
minimize = info$minimize,
16-
properties = if (weights) "weights" else character(),
23+
properties = properties,
1724
predict_type = info$predict_type,
1825
task_properties = "twoclass",
1926
packages = "mlr3measures",
@@ -22,9 +29,7 @@ MeasureBinarySimple = R6Class("MeasureBinarySimple",
2229
)
2330

2431
self$fun = get(name, envir = asNamespace("mlr3measures"), mode = "function")
25-
if (!is.na(info$obs_loss)) {
26-
self$obs_loss = get(info$obs_loss, envir = asNamespace("mlr3measures"), mode = "function")
27-
}
32+
2833
if (test_list(info$trafo)) {
2934
self$trafo = info$trafo
3035
}
@@ -41,7 +46,18 @@ MeasureBinarySimple = R6Class("MeasureBinarySimple",
4146
)
4247
},
4348

44-
.extra_hash = c("fun", "na_value")
49+
.extra_hash = c("fun", "fun_obs_loss", "na_value"),
50+
51+
.obs_loss = function(prediction, ...) {
52+
truth = prediction$truth
53+
positive = levels(truth)[1L]
54+
invoke(self$fun_obs_loss,
55+
.args = self$param_set$get_values(),
56+
truth = truth,
57+
response = prediction$response,
58+
prob = prediction$prob[, positive],
59+
positive = positive)
60+
}
4561
)
4662
)
4763

@@ -50,26 +66,29 @@ MeasureClassifSimple = R6Class("MeasureClassifSimple",
5066
inherit = MeasureClassif,
5167
public = list(
5268
fun = NULL,
69+
fun_obs_loss = NULL,
5370
na_value = NaN,
5471
initialize = function(name, param_set = ps()) {
5572
info = mlr3measures::measures[[name]]
5673
weights = info$sample_weights
74+
properties = if (weights) "weights" else character()
75+
if (!is.na(info$obs_loss)) {
76+
properties = c(properties, "obs_loss")
77+
self$fun_obs_loss = get(info$obs_loss, envir = asNamespace("mlr3measures"), mode = "function")
78+
}
5779

5880
super$initialize(
5981
id = paste0("classif.", name),
6082
param_set = param_set,
6183
range = c(info$lower, info$upper),
6284
minimize = info$minimize,
63-
properties = if (weights) "weights" else character(),
85+
properties = properties,
6486
predict_type = info$predict_type,
6587
packages = "mlr3measures",
6688
label = info$title,
6789
man = paste0("mlr3::mlr_measures_classif.", name)
6890
)
6991
self$fun = get(name, envir = asNamespace("mlr3measures"), mode = "function")
70-
if (!is.na(info$obs_loss)) {
71-
self$obs_loss = get(info$obs_loss, envir = asNamespace("mlr3measures"), mode = "function")
72-
}
7392
if (test_list(info$trafo)) {
7493
self$trafo = info$trafo
7594
}
@@ -82,7 +101,15 @@ MeasureClassifSimple = R6Class("MeasureClassifSimple",
82101
na_value = self$na_value, sample_weights = weights)
83102
},
84103

85-
.extra_hash = c("fun", "na_value")
104+
.extra_hash = c("fun", "fun_obs_loss", "na_value"),
105+
106+
.obs_loss = function(prediction, ...) {
107+
invoke(self$fun_obs_loss,
108+
.args = self$param_set$get_values(),
109+
truth = prediction$truth,
110+
response = prediction$response,
111+
prob = prediction$prob)
112+
}
86113
)
87114
)
88115

@@ -91,6 +118,7 @@ MeasureRegrSimple = R6Class("MeasureRegrSimple",
91118
inherit = MeasureRegr,
92119
public = list(
93120
fun = NULL,
121+
fun_obs_loss = NULL,
94122
na_value = NaN,
95123
initialize = function(name, param_set = NULL) {
96124
if (is.null(param_set)) {
@@ -103,22 +131,25 @@ MeasureRegrSimple = R6Class("MeasureRegrSimple",
103131

104132
info = mlr3measures::measures[[name]]
105133
weights = info$sample_weights
134+
properties = if (weights) "weights" else character()
135+
if (!is.na(info$obs_loss)) {
136+
properties = c(properties, "obs_loss")
137+
self$fun_obs_loss = get(info$obs_loss, envir = asNamespace("mlr3measures"), mode = "function")
138+
}
106139

107140
super$initialize(
108141
id = paste0("regr.", name),
109142
param_set = param_set$clone(),
110143
range = c(info$lower, info$upper),
111144
minimize = info$minimize,
112-
properties = if (weights) "weights" else character(),
145+
properties = properties,
113146
predict_type = info$predict_type,
114147
packages = "mlr3measures",
115148
label = info$title,
116149
man = paste0("mlr3::mlr_measures_regr.", name)
117150
)
118151
self$fun = get(name, envir = asNamespace("mlr3measures"), mode = "function")
119-
if (!is.na(info$obs_loss)) {
120-
self$obs_loss = get(info$obs_loss, envir = asNamespace("mlr3measures"), mode = "function")
121-
}
152+
122153
if (test_list(info$trafo)) {
123154
self$trafo = info$trafo
124155
}
@@ -131,7 +162,15 @@ MeasureRegrSimple = R6Class("MeasureRegrSimple",
131162
na_value = self$na_value, sample_weights = weights)
132163
},
133164

134-
.extra_hash = c("fun", "na_value")
165+
.extra_hash = c("fun", "fun_obs_loss", "na_value"),
166+
167+
.obs_loss = function(prediction, ...) {
168+
invoke(self$fun_obs_loss,
169+
.args = self$param_set$get_values(),
170+
truth = prediction$truth,
171+
response = prediction$response,
172+
se = prediction$se)
173+
}
135174
)
136175
)
137176

R/Prediction.R

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,15 @@ Prediction = R6Class("Prediction",
9696
},
9797

9898
#' @description
99-
#' Calculates the observation-wise loss via the loss function set in the
100-
#' [Measure]'s field `obs_loss`.
101-
#' Returns a `data.table()` with the columns `row_ids`, `truth`, `response` and
102-
#' one additional numeric column for each measure, named with the respective measure id.
103-
#' If there is no observation-wise loss function for the measure, the column is filled with
104-
#' `NA` values.
105-
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an
106-
#' additional transformation after aggregation, in this example taking the square-root.
99+
#' Calculates the observation-wise loss via the [Measure]'s `obs_loss` method.
100+
#' Returns a `data.table()` with the columns of the matching [Prediction] object plus one additional numeric column for each measure, named with the respective measure id.
101+
#' If there is no observation-wise loss function for the measure, the column is filled with `NA_real_` values.
102+
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an additional transformation after aggregation, in this example taking the square-root.
107103
obs_loss = function(measures = NULL) {
108104
measures = assert_measures(as_measures(measures, task_type = self$task_type))
109-
get_obs_loss(as.data.table(self), measures)
105+
tab = as.data.table(self)
106+
walk(measures, function(m) set(tab, j = m$id, value = m$obs_loss(prediction = self)))
107+
tab[]
110108
},
111109

112110

R/ResampleResult.R

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,23 +190,20 @@ ResampleResult = R6Class("ResampleResult",
190190
},
191191

192192
#' @description
193-
#' Calculates the observation-wise loss via the loss function set in the
194-
#' [Measure]'s field `obs_loss`.
195-
#' Returns a `data.table()` with the columns of the matching [Prediction] object plus
196-
#' one additional numeric column for each measure, named with the respective measure id.
197-
#' If there is no observation-wise loss function for the measure, the column is filled with
198-
#' `NA` values.
199-
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an
200-
#' additional transformation after aggregation, in this example taking the square-root.
193+
#' Calculates the observation-wise loss via the [Measure]'s `obs_loss` method.
194+
#' Returns a `data.table()` with an `iteration` column plus one numeric column for each measure, named with the respective measure id.
195+
#' If there is no observation-wise loss function for the measure, the column is filled with `NA_real_` values.
196+
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an additional transformation after aggregation, in this example taking the square-root.
201197
#'
202198
#' @param predict_sets (`character()`)\cr
203199
#' The predict sets.
204200
#' @examples
205201
#' rr$obs_loss(msr("classif.acc"))
206202
obs_loss = function(measures = NULL, predict_sets = "test") {
207203
measures = assert_measures(as_measures(measures, task_type = self$task_type))
208-
tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration")
209-
get_obs_loss(tab, measures)
204+
map_dtr(self$predictions(predict_sets), function(pred) {
205+
pred$obs_loss(measures)
206+
}, .idcol = "iteration")
210207
},
211208

212209
#' @description

R/helper.R

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,6 @@ assert_validate = function(x) {
6262
assert_choice(x, c("predefined", "test"), null.ok = TRUE)
6363
}
6464

65-
66-
get_obs_loss = function(tab, measures) {
67-
for (measure in measures) {
68-
fun = measure$obs_loss
69-
value = if (is.function(fun)) {
70-
args = intersect(names(tab), names(formals(fun)))
71-
do.call(fun, tab[, args, with = FALSE])
72-
} else {
73-
NA_real_
74-
}
75-
76-
set(tab, j = measure$id, value = value)
77-
}
78-
79-
tab[]
80-
}
81-
8265
# Generalization of quantile(type = 7) for weighted data.
8366

8467
quantile_weighted = function(x, probs, na.rm = FALSE, weights = NULL, digits = 7L, continuous = TRUE) {

R/mlr_reflections.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ local({
147147
mlr_reflections$resampling_properties = c("duplicated_ids", "weights")
148148

149149
### Measures
150-
tmp = c("na_score", "requires_task", "requires_learner", "requires_model", "requires_train_set", "weights", "primary_iters", "requires_no_prediction")
150+
tmp = c("na_score", "requires_task", "requires_learner", "requires_model", "requires_train_set", "weights", "primary_iters", "requires_no_prediction", "obs_loss")
151151
mlr_reflections$measure_properties = list(
152152
classif = tmp,
153153
regr = tmp

0 commit comments

Comments
 (0)