Skip to content

Commit 92cf7c0

Browse files
committed
fix(PredictionDataRegr): pass quantile attributes in concat
1 parent 9c38ad9 commit 92cf7c0

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

R/PredictionDataRegr.R

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,19 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint
110110
error_input("Cannot rbind predictions: Some predictions have extra data, others do not")
111111
}
112112

113-
elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")), if ("weights" %chin% names(dots[[1L]])) "weights")
113+
nn = names(dots[[1L]])
114+
elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")), if ("weights" %chin% nn) "weights")
114115
tab = map_dtr(dots, function(x) x[elems], .fill = FALSE)
115-
quantiles = do.call(rbind, map(dots, "quantiles"))
116116

117-
extra = if ("extra" %chin% names(dots[[1L]])) {
117+
quantiles = if ("quantiles" %chin% nn) {
118+
quantiles = map(dots, "quantiles")
119+
extra_attributes = attributes(quantiles[[1L]])
120+
extra_attributes = extra_attributes[names(extra_attributes) %nin% c("dim", "dimnames")]
121+
quantiles = do.call(rbind, quantiles)
122+
iwalk(extra_attributes, function(x, nm) setattr(quantiles, nm, x))
123+
quantiles
124+
}
125+
extra = if ("extra" %chin% nn) {
118126
rbindlist(map(dots, "extra"), fill = TRUE, use.names = TRUE)
119127
}
120128

R/PredictionRegr.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
6565
weights = NULL,
6666
check = TRUE,
6767
extra = NULL
68-
) {
68+
) {
6969
pdata = new_prediction_data(
7070
list(row_ids = row_ids, truth = truth, response = response, se = se, quantiles = quantiles, distr = distr, weights = weights, extra = extra),
7171
task_type = "regr"

0 commit comments

Comments
 (0)