diff --git a/NAMESPACE b/NAMESPACE
index c5dab6d..91cc64b 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -5,6 +5,7 @@ S3method(bundle,H2OBinomialModel)
S3method(bundle,H2OMultinomialModel)
S3method(bundle,H2ORegressionModel)
S3method(bundle,bart)
+S3method(bundle,catboost.Model)
S3method(bundle,default)
S3method(bundle,keras.engine.training.Model)
S3method(bundle,luz_module_fitted)
diff --git a/NEWS.md b/NEWS.md
index f51fbeb..10206da 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,7 @@
# bundle (development version)
+* Added bundle method for catboost models and, by extension, `parsnip::boost_tree(engine = "catboost")` via bonsai (#82).
+
# bundle 0.1.3
* Updated to support new versions of xgboost models (#75).
diff --git a/R/bundle_catboost.R b/R/bundle_catboost.R
new file mode 100644
index 0000000..3a92deb
--- /dev/null
+++ b/R/bundle_catboost.R
@@ -0,0 +1,73 @@
+#' @templateVar class a `catboost.Model`
+#' @template title_desc
+#'
+#' @templateVar outclass `bundled_catboost.Model`
+#' @templateVar default .
+#' @template return_bundle
+#' @family bundlers
+#'
+#' @param x A `catboost.Model` object returned from `catboost::catboost.train()`.
+#' @template param_unused_dots
+#' @rdname bundle_catboost
+#' @seealso This method stores the raw serialized model bytes from the
+#' catboost model object and restores the C++ handle on unbundle.
+#' @template butcher_details
+#' @examplesIf rlang::is_installed(c("catboost", "parsnip", "bonsai"))
+#' # fit model and bundle ------------------------------------------------
+#' library(parsnip)
+#' library(bonsai)
+#'
+#' set.seed(1)
+#'
+#' mod <- boost_tree(trees = 10) %>%
+#' set_engine("catboost", verbose = 0) %>%
+#' set_mode("classification") %>%
+#'
+#' fit(Species ~ ., data = iris)
+#' # extract the underlying catboost model
+#' catboost_model <- mod$fit
+#'
+#' model_bundle <- bundle(catboost_model)
+#'
+#' # then, after saveRDS + readRDS or passing to a new session ----------
+#' model_unbundled <- unbundle(model_bundle)
+#' @aliases bundle.catboost.Model
+#' @method bundle catboost.Model
+#' @export
+bundle.catboost.Model <- function(x, ...) {
+  rlang::check_installed("catboost")
+  rlang::check_dots_empty()
+
+  # Serialize with catboost to a scratch file, then keep only its bytes
+  cbm_path <- withr::local_tempfile(fileext = ".cbm")
+  save_call <- rlang::call2(
+    "catboost.save_model", x, cbm_path,
+    .ns = "catboost"
+  )
+  rlang::eval_tidy(save_call)
+  object <- readBin(cbm_path, what = "raw", n = file.size(cbm_path))
+
+  bundle_constr(
+    object = object,
+    situate = situate_constr(function(object) {
+      # Dump the bytes back to disk so catboost's loader can read them
+      cbm_path <- withr::local_tempfile(fileext = ".cbm")
+      writeBin(object, cbm_path)
+      load_call <- rlang::call2(
+        "catboost.load_model", cbm_path,
+        .ns = "catboost"
+      )
+      restored <- rlang::eval_tidy(load_call)
+
+      # Copy selected elements of the original model onto the reloaded one;
+      # the `!!` values are presumably inlined by situate_constr() -- confirm
+      restored$feature_importances <- !!x$feature_importances
+      restored$tree_count <- !!x$tree_count
+      restored$learning_rate <- !!x$learning_rate
+      restored$feature_count <- !!x$feature_count
+
+      restored
+    }),
+    desc_class = class(x)[1]
+  )
+}
diff --git a/man/bundle.Rd b/man/bundle.Rd
index 18e304c..7350db3 100644
--- a/man/bundle.Rd
+++ b/man/bundle.Rd
@@ -62,6 +62,7 @@ then re-loaded and \code{unbundle()}d in a new R session for use in prediction.
Other bundlers:
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_bart.Rd b/man/bundle_bart.Rd
index 545db0c..9174182 100644
--- a/man/bundle_bart.Rd
+++ b/man/bundle_bart.Rd
@@ -97,6 +97,7 @@ fit_unbundled_preds <- predict(fit_unbundled, mtcars)
Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_caret.Rd b/man/bundle_caret.Rd
index c4a4a50..82a3c12 100644
--- a/man/bundle_caret.Rd
+++ b/man/bundle_caret.Rd
@@ -110,6 +110,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_catboost.Rd b/man/bundle_catboost.Rd
new file mode 100644
index 0000000..e32144b
--- /dev/null
+++ b/man/bundle_catboost.Rd
@@ -0,0 +1,117 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/bundle_catboost.R
+\name{bundle.catboost.Model}
+\alias{bundle.catboost.Model}
+\title{Bundle a \code{catboost.Model} object}
+\usage{
+\method{bundle}{catboost.Model}(x, ...)
+}
+\arguments{
+\item{x}{A \code{catboost.Model} object returned from \code{catboost::catboost.train()}.}
+
+\item{...}{Not used in this bundler and included for compatibility with
+the generic only. Additional arguments passed to this method will return
+an error.}
+}
+\value{
+A bundle object with subclass \code{bundled_catboost.Model}.
+
+Bundles are a list subclass with two components:
+
+\item{object}{An R object. Gives the output of native serialization
+methods from the model-supplying package, sometimes with additional
+classes or attributes that aid portability. This is often
+a \link[base:raw]{raw} object.}
+\item{situate}{A function. The \code{situate()} function is defined when
+\code{\link[=bundle]{bundle()}} is called, though is a loose analogue of an \code{\link[=unbundle]{unbundle()}} S3
+method for that object. Since the function is defined on \code{\link[=bundle]{bundle()}}, it
+has access to references and dependency information that can
+be saved alongside the \code{object} component. Calling \code{\link[=unbundle]{unbundle()}} on a
+bundled object \code{x} calls \code{x$situate(x$object)}, returning the
+unserialized version of \code{object}. \code{situate()} will also restore needed
+references, such as server instances and environmental variables.}
+
+Bundles are R objects that represent a "standalone" version of their
+analogous model object. Thus, bundles are ready for saving to a file; saving
+with \code{\link[base:readRDS]{base::saveRDS()}} is our recommended serialization strategy for bundles,
+unless documented otherwise for a specific method.
+
+To restore the original model object \code{x} in a new environment, load its
+bundle with \code{\link[base:readRDS]{base::readRDS()}} and run \code{\link[=unbundle]{unbundle()}} on it. The output
+of \code{\link[=unbundle]{unbundle()}} is a model object that is ready to \code{\link[=predict]{predict()}} on new data,
+and other restored functionality (like plotting or summarizing) is supported
+as a side effect only.
+
+The bundle package wraps native serialization methods from model-supplying
+packages. Between versions, those model-supplying packages may change their
+native serialization methods, possibly introducing problems with re-loading
+objects serialized with previous package versions. The bundle package does
+not provide checks for these sorts of changes, and ought to be used in
+conjunction with tooling for managing and monitoring model environments
+like \link[vetiver:vetiver-package]{vetiver} or \link[renv:renv-package]{renv}.
+
+See \code{vignette("bundle")} for more information on bundling and its motivation.
+}
+\description{
+Bundling a model prepares it to be saved to a file and later
+restored for prediction in a new R session. See the 'Value' section for
+more information on bundles and their usage.
+}
+\section{bundle and butcher}{
+
+The \href{https://butcher.tidymodels.org/}{butcher} package allows you to remove
+parts of a fitted model object that are not needed for prediction.
+
+This bundle method is compatible with pre-butchering. That is, for a
+fitted model \code{x}, you can safely call:
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{res <-
+ x |>
+ butcher() |>
+ bundle()
+}\if{html}{\out{</div>}}
+
+and predict with the output of \code{unbundle(res)} in a new R session.
+}
+
+\examples{
+\dontshow{if (rlang::is_installed(c("catboost", "parsnip", "bonsai"))) withAutoprint(\{ # examplesIf}
+# fit model and bundle ------------------------------------------------
+library(parsnip)
+library(bonsai)
+
+set.seed(1)
+
+mod <- boost_tree(trees = 10) \%>\%
+ set_engine("catboost", verbose = 0) \%>\%
+ set_mode("classification") \%>\%
+
+ fit(Species ~ ., data = iris)
+# extract the underlying catboost model
+catboost_model <- mod$fit
+
+model_bundle <- bundle(catboost_model)
+
+# then, after saveRDS + readRDS or passing to a new session ----------
+model_unbundled <- unbundle(model_bundle)
+\dontshow{\}) # examplesIf}
+}
+\seealso{
+This method stores the raw serialized model bytes from the
+catboost model object and restores the C++ handle on unbundle.
+
+Other bundlers:
+\code{\link{bundle}()},
+\code{\link{bundle.H2OAutoML}()},
+\code{\link{bundle.bart}()},
+\code{\link{bundle.keras.engine.training.Model}()},
+\code{\link{bundle.luz_module_fitted}()},
+\code{\link{bundle.model_fit}()},
+\code{\link{bundle.model_stack}()},
+\code{\link{bundle.recipe}()},
+\code{\link{bundle.step_umap}()},
+\code{\link{bundle.train}()},
+\code{\link{bundle.workflow}()},
+\code{\link{bundle.xgb.Booster}()}
+}
+\concept{bundlers}
diff --git a/man/bundle_embed.Rd b/man/bundle_embed.Rd
index f4fab8d..f60689d 100644
--- a/man/bundle_embed.Rd
+++ b/man/bundle_embed.Rd
@@ -104,6 +104,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_h2o.Rd b/man/bundle_h2o.Rd
index b62f97d..7095238 100644
--- a/man/bundle_h2o.Rd
+++ b/man/bundle_h2o.Rd
@@ -110,6 +110,7 @@ These methods wrap \code{\link[h2o:h2o.save_mojo]{h2o::h2o.save_mojo()}} and
Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_keras.Rd b/man/bundle_keras.Rd
index a0fc958..024ed8b 100644
--- a/man/bundle_keras.Rd
+++ b/man/bundle_keras.Rd
@@ -123,6 +123,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
\code{\link{bundle.model_stack}()},
diff --git a/man/bundle_parsnip.Rd b/man/bundle_parsnip.Rd
index aa2f528..399152e 100644
--- a/man/bundle_parsnip.Rd
+++ b/man/bundle_parsnip.Rd
@@ -108,6 +108,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_stack}()},
diff --git a/man/bundle_recipe.Rd b/man/bundle_recipe.Rd
index 480f879..5a0beec 100644
--- a/man/bundle_recipe.Rd
+++ b/man/bundle_recipe.Rd
@@ -69,6 +69,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_stacks.Rd b/man/bundle_stacks.Rd
index 1c3981c..6ca0b38 100644
--- a/man/bundle_stacks.Rd
+++ b/man/bundle_stacks.Rd
@@ -88,6 +88,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_torch.Rd b/man/bundle_torch.Rd
index f71d20a..dea62e3 100644
--- a/man/bundle_torch.Rd
+++ b/man/bundle_torch.Rd
@@ -151,6 +151,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.model_fit}()},
\code{\link{bundle.model_stack}()},
diff --git a/man/bundle_workflows.Rd b/man/bundle_workflows.Rd
index a69497c..32b9789 100644
--- a/man/bundle_workflows.Rd
+++ b/man/bundle_workflows.Rd
@@ -115,6 +115,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/man/bundle_xgboost.Rd b/man/bundle_xgboost.Rd
index b8ed244..161e9fb 100644
--- a/man/bundle_xgboost.Rd
+++ b/man/bundle_xgboost.Rd
@@ -111,6 +111,7 @@ Other bundlers:
\code{\link{bundle}()},
\code{\link{bundle.H2OAutoML}()},
\code{\link{bundle.bart}()},
+\code{\link{bundle.catboost.Model}()},
\code{\link{bundle.keras.engine.training.Model}()},
\code{\link{bundle.luz_module_fitted}()},
\code{\link{bundle.model_fit}()},
diff --git a/tests/testthat/test_bundle_catboost.R b/tests/testthat/test_bundle_catboost.R
new file mode 100644
index 0000000..2b06357
--- /dev/null
+++ b/tests/testthat/test_bundle_catboost.R
@@ -0,0 +1,180 @@
+test_that("bundling + unbundling catboost fits", {
+  skip_if_not_installed("catboost")
+  skip_if_not_installed("bonsai")
+  skip_if_not_installed("parsnip")
+  skip_if_not_installed("butcher")
+
+  library(parsnip)
+  library(bonsai)
+
+  # fitting helper: returns the engine-level fit and is handed unchanged to
+  # the fresh sessions below --------------------------------------------------
+  fit_model <- function() {
+    library(parsnip)
+    library(bonsai)
+
+    spec <- boost_tree(trees = 10) %>%
+      set_engine("catboost", verbose = 0, random_seed = 1) %>%
+      set_mode("classification")
+    parsnip_fit <- spec %>% fit(Species ~ ., data = iris)
+
+    # only the `fit` element of the parsnip object is returned
+    parsnip_fit$fit
+  }
+
+  # session 1: fit the model in a new session and hand back its bundle
+  # ---------------------------------------------------------------------------
+  mod_bundle <- callr::r(
+    function(fitter) {
+      bundle::bundle(fitter())
+    },
+    args = list(fitter = fit_model)
+  )
+
+  # session 2: unbundle in another new session and predict classes
+  # ---------------------------------------------------------------------------
+  mod_unbundled_preds <- callr::r(
+    function(bundled, new_data) {
+      library(catboost)
+
+      restored <- bundle::unbundle(bundled)
+
+      pool <- catboost.load_pool(data = new_data)
+
+      catboost.predict(restored, pool, prediction_type = "Class")
+    },
+    args = list(
+      bundled = mod_bundle,
+      new_data = iris[, 1:4]
+    )
+  )
+
+  # session 3: fit, butcher, then bundle
+  # ---------------------------------------------------------------------------
+  mod_butchered_bundle <- callr::r(
+    function(fitter) {
+      bundle::bundle(butcher::butcher(fitter()))
+    },
+    args = list(fitter = fit_model)
+  )
+
+  # session 4: unbundle the butchered bundle and predict classes
+  # ---------------------------------------------------------------------------
+  mod_butchered_unbundled_preds <- callr::r(
+    function(bundled, new_data) {
+      library(bundle)
+      library(catboost)
+
+      restored <- unbundle(bundled)
+
+      pool <- catboost.load_pool(data = new_data)
+
+      catboost.predict(restored, pool, prediction_type = "Class")
+    },
+    args = list(
+      bundled = mod_butchered_bundle,
+      new_data = iris[, 1:4]
+    )
+  )
+
+  # reference fit and predictions computed in the current session
+  # ---------------------------------------------------------------------------
+  mod_fit <- fit_model()
+  reference_pool <- catboost::catboost.load_pool(data = iris[, 1:4])
+  mod_preds <- catboost::catboost.predict(
+    mod_fit,
+    reference_pool,
+    prediction_type = "Class"
+  )
+
+  # class of the bundle and of what comes back out of it
+  expect_s3_class(mod_bundle, "bundled_catboost.Model")
+  expect_s3_class(unbundle(mod_bundle), "catboost.Model")
+
+  # the situate function's environment must not hold the model as `x`
+  situate_env <- environment(mod_bundle$situate)
+  expect_false("x" %in% names(situate_env))
+
+  # unexpected arguments in `...` are rejected
+  expect_error(
+    bundle(mod_fit, boop = "bop"),
+    class = "rlib_error_dots"
+  )
+
+  # predictions survive the round trip, with and without butchering
+  expect_equal(mod_preds, mod_unbundled_preds)
+  expect_equal(mod_preds, mod_butchered_unbundled_preds)
+
+  # metadata fields of the unbundled model match those of the fit made in
+  # this session
+  mod_unbundled <- unbundle(mod_bundle)
+  expect_identical(mod_unbundled$tree_count, mod_fit$tree_count)
+  expect_identical(mod_unbundled$learning_rate, mod_fit$learning_rate)
+  expect_identical(mod_unbundled$feature_count, mod_fit$feature_count)
+})
+
+test_that("bundling + unbundling bonsai catboost model_fit", {
+  skip_if_not_installed("catboost")
+  skip_if_not_installed("bonsai")
+  skip_if_not_installed("parsnip")
+
+  # fitting helper returning the whole parsnip fit; it is handed unchanged
+  # to a fresh session below --------------------------------------------------
+  fit_model <- function() {
+    library(parsnip)
+    library(bonsai)
+
+    boost_tree(trees = 10) %>%
+      set_engine("catboost", verbose = 0, random_seed = 1) %>%
+      set_mode("classification") %>%
+      fit(Species ~ ., data = iris)
+  }
+
+  # session 1: fit, predict with the fresh fit, then bundle it; both the
+  # bundle and the pre-bundle predictions come back ---------------------------
+  result <- callr::r(
+    function(fitter, new_data) {
+      library(parsnip)
+      library(bonsai)
+
+      parsnip_fit <- fitter()
+
+      # predictions are taken before the fit is bundled
+      preds <- predict(parsnip_fit, new_data)
+
+      bundled <- bundle::bundle(parsnip_fit)
+
+      list(bundle = bundled, preds_before = preds)
+    },
+    args = list(
+      fitter = fit_model,
+      new_data = iris
+    )
+  )
+
+  mod_bundle <- result[["bundle"]]
+  mod_preds <- result[["preds_before"]]
+
+  # session 2: unbundle and predict on the same data --------------------------
+  mod_unbundled_preds <- callr::r(
+    function(bundled, new_data) {
+      library(parsnip)
+      library(bonsai)
+
+      restored <- bundle::unbundle(bundled)
+
+      predict(restored, new_data)
+    },
+    args = list(
+      bundled = mod_bundle,
+      new_data = iris
+    )
+  )
+
+  # expectations --------------------------------------------------------------
+  # the bundle carries the `bundled_model_fit` class
+  expect_s3_class(mod_bundle, "bundled_model_fit")
+
+  # predictions are unchanged by the bundle / unbundle round trip
+  expect_equal(mod_preds, mod_unbundled_preds)
+})