Skip to content

Commit 21dc3a0

Browse files
committed
fix(PredictionDataRegr): pass quantile attributes in concat
1 parent 9c38ad9 commit 21dc3a0

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

R/PredictionDataRegr.R

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,18 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint
110110
error_input("Cannot rbind predictions: Some predictions have extra data, others do not")
111111
}
112112

113-
elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")), if ("weights" %chin% names(dots[[1L]])) "weights")
113+
nn = names(dots[[1L]])
114+
elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")), if ("weights" %chin% nn) "weights")
114115
tab = map_dtr(dots, function(x) x[elems], .fill = FALSE)
115-
quantiles = do.call(rbind, map(dots, "quantiles"))
116116

117-
extra = if ("extra" %chin% names(dots[[1L]])) {
117+
quantiles = if ("quantiles" %chin% nn) {
118+
quantiles = map(dots, "quantiles")
119+
attrs = attributes(quantiles[[1L]])
120+
quantiles = do.call(rbind, quantiles)
121+
setattr(quantiles, "probs", attrs$props)
122+
setattr(quantiles, "response", attrs$response)
123+
}
124+
extra = if ("extra" %chin% nn) {
118125
rbindlist(map(dots, "extra"), fill = TRUE, use.names = TRUE)
119126
}
120127

R/PredictionRegr.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
6565
weights = NULL,
6666
check = TRUE,
6767
extra = NULL
68-
) {
68+
) {
6969
pdata = new_prediction_data(
7070
list(row_ids = row_ids, truth = truth, response = response, se = se, quantiles = quantiles, distr = distr, weights = weights, extra = extra),
7171
task_type = "regr"

0 commit comments

Comments
 (0)