diff --git a/NAMESPACE b/NAMESPACE index c5dab6d..91cc64b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,6 +5,7 @@ S3method(bundle,H2OBinomialModel) S3method(bundle,H2OMultinomialModel) S3method(bundle,H2ORegressionModel) S3method(bundle,bart) +S3method(bundle,catboost.Model) S3method(bundle,default) S3method(bundle,keras.engine.training.Model) S3method(bundle,luz_module_fitted) diff --git a/NEWS.md b/NEWS.md index f51fbeb..10206da 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # bundle (development version) +* Added bundle method for catboost models and, by extension, `parsnip::boost_tree(engine = "catboost")` via bonsai (#82). + # bundle 0.1.3 * Updated to support new versions of xgboost models (#75). diff --git a/R/bundle_catboost.R b/R/bundle_catboost.R new file mode 100644 index 0000000..3a92deb --- /dev/null +++ b/R/bundle_catboost.R @@ -0,0 +1,73 @@ +#' @templateVar class a `catboost.Model` +#' @template title_desc +#' +#' @templateVar outclass `bundled_catboost.Model` +#' @templateVar default . +#' @template return_bundle +#' @family bundlers +#' +#' @param x A `catboost.Model` object returned from `catboost::catboost.train()`. +#' @template param_unused_dots +#' @rdname bundle_catboost +#' @seealso This method stores the raw serialized model bytes from the +#' catboost model object and restores the C++ handle on unbundle. 
+#' @template butcher_details +#' @examplesIf rlang::is_installed(c("catboost", "parsnip", "bonsai")) +#' # fit model and bundle ------------------------------------------------ +#' library(parsnip) +#' library(bonsai) +#' +#' set.seed(1) +#' +#' mod <- boost_tree(trees = 10) %>% +#' set_engine("catboost", verbose = 0) %>% +#' set_mode("classification") %>% +#' +#' fit(Species ~ ., data = iris) +#' # extract the underlying catboost model +#' catboost_model <- mod$fit +#' +#' model_bundle <- bundle(catboost_model) +#' +#' # then, after saveRDS + readRDS or passing to a new session ---------- +#' model_unbundled <- unbundle(model_bundle) +#' @aliases bundle.catboost.Model +#' @method bundle catboost.Model +#' @export +bundle.catboost.Model <- function(x, ...) { + rlang::check_installed("catboost") + rlang::check_dots_empty() + + # Save model to temp file and read raw bytes + tmp_file <- withr::local_tempfile(fileext = ".cbm") + rlang::eval_tidy(rlang::call2( + "catboost.save_model", + x, + tmp_file, + .ns = "catboost" + )) + object <- readBin(tmp_file, "raw", file.size(tmp_file)) + + bundle_constr( + object = object, + situate = situate_constr(function(object) { + # Write raw bytes to temp file and load model + tmp_file <- withr::local_tempfile(fileext = ".cbm") + writeBin(object, tmp_file) + model <- rlang::eval_tidy(rlang::call2( + "catboost.load_model", + tmp_file, + .ns = "catboost" + )) + + # Restore metadata + model$feature_importances <- !!x$feature_importances + model$tree_count <- !!x$tree_count + model$learning_rate <- !!x$learning_rate + model$feature_count <- !!x$feature_count + + model + }), + desc_class = class(x)[1] + ) +} diff --git a/man/bundle.Rd b/man/bundle.Rd index 18e304c..7350db3 100644 --- a/man/bundle.Rd +++ b/man/bundle.Rd @@ -62,6 +62,7 @@ then re-loaded and \code{unbundle()}d in a new R session for use in prediction. 
Other bundlers: \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_bart.Rd b/man/bundle_bart.Rd index 545db0c..9174182 100644 --- a/man/bundle_bart.Rd +++ b/man/bundle_bart.Rd @@ -97,6 +97,7 @@ fit_unbundled_preds <- predict(fit_unbundled, mtcars) Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_caret.Rd b/man/bundle_caret.Rd index c4a4a50..82a3c12 100644 --- a/man/bundle_caret.Rd +++ b/man/bundle_caret.Rd @@ -110,6 +110,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_catboost.Rd b/man/bundle_catboost.Rd new file mode 100644 index 0000000..e32144b --- /dev/null +++ b/man/bundle_catboost.Rd @@ -0,0 +1,117 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bundle_catboost.R +\name{bundle.catboost.Model} +\alias{bundle.catboost.Model} +\title{Bundle a \code{catboost.Model} object} +\usage{ +\method{bundle}{catboost.Model}(x, ...) +} +\arguments{ +\item{x}{A \code{catboost.Model} object returned from \code{catboost::catboost.train()}.} + +\item{...}{Not used in this bundler and included for compatibility with +the generic only. Additional arguments passed to this method will return +an error.} +} +\value{ +A bundle object with subclass \code{bundled_catboost.Model}. + +Bundles are a list subclass with two components: + +\item{object}{An R object. 
Gives the output of native serialization +methods from the model-supplying package, sometimes with additional +classes or attributes that aid portability. This is often +a \link[base:raw]{raw} object.} +\item{situate}{A function. The \code{situate()} function is defined when +\code{\link[=bundle]{bundle()}} is called, though is a loose analogue of an \code{\link[=unbundle]{unbundle()}} S3 +method for that object. Since the function is defined on \code{\link[=bundle]{bundle()}}, it +has access to references and dependency information that can +be saved alongside the \code{object} component. Calling \code{\link[=unbundle]{unbundle()}} on a +bundled object \code{x} calls \code{x$situate(x$object)}, returning the +unserialized version of \code{object}. \code{situate()} will also restore needed +references, such as server instances and environmental variables.} + +Bundles are R objects that represent a "standalone" version of their +analogous model object. Thus, bundles are ready for saving to a file; saving +with \code{\link[base:readRDS]{base::saveRDS()}} is our recommended serialization strategy for bundles, +unless documented otherwise for a specific method. + +To restore the original model object \code{x} in a new environment, load its +bundle with \code{\link[base:readRDS]{base::readRDS()}} and run \code{\link[=unbundle]{unbundle()}} on it. The output +of \code{\link[=unbundle]{unbundle()}} is a model object that is ready to \code{\link[=predict]{predict()}} on new data, +and other restored functionality (like plotting or summarizing) is supported +as a side effect only. + +The bundle package wraps native serialization methods from model-supplying +packages. Between versions, those model-supplying packages may change their +native serialization methods, possibly introducing problems with re-loading +objects serialized with previous package versions. 
The bundle package does +not provide checks for these sorts of changes, and ought to be used in +conjunction with tooling for managing and monitoring model environments +like \link[vetiver:vetiver-package]{vetiver} or \link[renv:renv-package]{renv}. + +See \code{vignette("bundle")} for more information on bundling and its motivation. +} +\description{ +Bundling a model prepares it to be saved to a file and later +restored for prediction in a new R session. See the 'Value' section for +more information on bundles and their usage. +} +\section{bundle and butcher}{ + +The \href{https://butcher.tidymodels.org/}{butcher} package allows you to remove +parts of a fitted model object that are not needed for prediction. + +This bundle method is compatible with pre-butchering. That is, for a +fitted model \code{x}, you can safely call: + +\if{html}{\out{
<div class="sourceCode">}}\preformatted{res <- + x |> + butcher() |> + bundle() +}\if{html}{\out{</div>
}} + +and predict with the output of \code{unbundle(res)} in a new R session. +} + +\examples{ +\dontshow{if (rlang::is_installed(c("catboost", "parsnip", "bonsai"))) withAutoprint(\{ # examplesIf} +# fit model and bundle ------------------------------------------------ +library(parsnip) +library(bonsai) + +set.seed(1) + +mod <- boost_tree(trees = 10) \%>\% + set_engine("catboost", verbose = 0) \%>\% + set_mode("classification") \%>\% + + fit(Species ~ ., data = iris) +# extract the underlying catboost model +catboost_model <- mod$fit + +model_bundle <- bundle(catboost_model) + +# then, after saveRDS + readRDS or passing to a new session ---------- +model_unbundled <- unbundle(model_bundle) +\dontshow{\}) # examplesIf} +} +\seealso{ +This method stores the raw serialized model bytes from the +catboost model object and restores the C++ handle on unbundle. + +Other bundlers: +\code{\link{bundle}()}, +\code{\link{bundle.H2OAutoML}()}, +\code{\link{bundle.bart}()}, +\code{\link{bundle.keras.engine.training.Model}()}, +\code{\link{bundle.luz_module_fitted}()}, +\code{\link{bundle.model_fit}()}, +\code{\link{bundle.model_stack}()}, +\code{\link{bundle.recipe}()}, +\code{\link{bundle.step_umap}()}, +\code{\link{bundle.train}()}, +\code{\link{bundle.workflow}()}, +\code{\link{bundle.xgb.Booster}()} +} +\concept{bundlers} diff --git a/man/bundle_embed.Rd b/man/bundle_embed.Rd index f4fab8d..f60689d 100644 --- a/man/bundle_embed.Rd +++ b/man/bundle_embed.Rd @@ -104,6 +104,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_h2o.Rd b/man/bundle_h2o.Rd index b62f97d..7095238 100644 --- a/man/bundle_h2o.Rd +++ b/man/bundle_h2o.Rd @@ -110,6 +110,7 @@ These methods wrap \code{\link[h2o:h2o.save_mojo]{h2o::h2o.save_mojo()}} and Other 
bundlers: \code{\link{bundle}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_keras.Rd b/man/bundle_keras.Rd index a0fc958..024ed8b 100644 --- a/man/bundle_keras.Rd +++ b/man/bundle_keras.Rd @@ -123,6 +123,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, \code{\link{bundle.model_stack}()}, diff --git a/man/bundle_parsnip.Rd b/man/bundle_parsnip.Rd index aa2f528..399152e 100644 --- a/man/bundle_parsnip.Rd +++ b/man/bundle_parsnip.Rd @@ -108,6 +108,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_stack}()}, diff --git a/man/bundle_recipe.Rd b/man/bundle_recipe.Rd index 480f879..5a0beec 100644 --- a/man/bundle_recipe.Rd +++ b/man/bundle_recipe.Rd @@ -69,6 +69,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_stacks.Rd b/man/bundle_stacks.Rd index 1c3981c..6ca0b38 100644 --- a/man/bundle_stacks.Rd +++ b/man/bundle_stacks.Rd @@ -88,6 +88,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_torch.Rd b/man/bundle_torch.Rd 
index f71d20a..dea62e3 100644 --- a/man/bundle_torch.Rd +++ b/man/bundle_torch.Rd @@ -151,6 +151,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.model_fit}()}, \code{\link{bundle.model_stack}()}, diff --git a/man/bundle_workflows.Rd b/man/bundle_workflows.Rd index a69497c..32b9789 100644 --- a/man/bundle_workflows.Rd +++ b/man/bundle_workflows.Rd @@ -115,6 +115,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/man/bundle_xgboost.Rd b/man/bundle_xgboost.Rd index b8ed244..161e9fb 100644 --- a/man/bundle_xgboost.Rd +++ b/man/bundle_xgboost.Rd @@ -111,6 +111,7 @@ Other bundlers: \code{\link{bundle}()}, \code{\link{bundle.H2OAutoML}()}, \code{\link{bundle.bart}()}, +\code{\link{bundle.catboost.Model}()}, \code{\link{bundle.keras.engine.training.Model}()}, \code{\link{bundle.luz_module_fitted}()}, \code{\link{bundle.model_fit}()}, diff --git a/tests/testthat/test_bundle_catboost.R b/tests/testthat/test_bundle_catboost.R new file mode 100644 index 0000000..2b06357 --- /dev/null +++ b/tests/testthat/test_bundle_catboost.R @@ -0,0 +1,180 @@ +test_that("bundling + unbundling catboost fits", { + skip_if_not_installed("catboost") + skip_if_not_installed("bonsai") + skip_if_not_installed("parsnip") + skip_if_not_installed("butcher") + + library(parsnip) + library(bonsai) + + # define a function to fit a model ------------------------------------------- + fit_model <- function() { + library(parsnip) + library(bonsai) + + model <- boost_tree(trees = 10) %>% + set_engine("catboost", verbose = 0, random_seed = 1) %>% + set_mode("classification") %>% + fit(Species ~ ., data = 
iris) + + model$fit + } + + # pass fit fn to a new session, fit, bundle, return bundle ------------------- + mod_bundle <- + callr::r( + function(fit_model) { + mod <- fit_model() + + bundle::bundle(mod) + }, + args = list(fit_model = fit_model) + ) + + # pass the bundle to a new session, unbundle it, return predictions ---------- + mod_unbundled_preds <- + callr::r( + function(mod_bundle, test_data) { + library(catboost) + + mod_unbundled <- bundle::unbundle(mod_bundle) + + test_pool <- catboost.load_pool(data = test_data) + + catboost.predict(mod_unbundled, test_pool, prediction_type = "Class") + }, + args = list( + mod_bundle = mod_bundle, + test_data = iris[, 1:4] + ) + ) + + # pass fit fn to a new session, fit, butcher, bundle, return bundle ---------- + mod_butchered_bundle <- + callr::r( + function(fit_model) { + mod <- fit_model() + + bundle::bundle(butcher::butcher(mod)) + }, + args = list(fit_model = fit_model) + ) + + # pass the bundle to a new session, unbundle it, return predictions ---------- + mod_butchered_unbundled_preds <- + callr::r( + function(mod_butchered_bundle, test_data) { + library(bundle) + library(catboost) + + mod_butchered_unbundled <- unbundle(mod_butchered_bundle) + + test_pool <- catboost.load_pool(data = test_data) + + catboost.predict( + mod_butchered_unbundled, + test_pool, + prediction_type = "Class" + ) + }, + args = list( + mod_butchered_bundle = mod_butchered_bundle, + test_data = iris[, 1:4] + ) + ) + + # run expectations ----------------------------------------------------------- + mod_fit <- fit_model() + test_pool <- catboost::catboost.load_pool(data = iris[, 1:4]) + mod_preds <- catboost::catboost.predict( + mod_fit, + test_pool, + prediction_type = "Class" + ) + + # check classes + expect_s3_class(mod_bundle, "bundled_catboost.Model") + expect_s3_class(unbundle(mod_bundle), "catboost.Model") + + # ensure that the situater function didn't bring along the whole model + expect_false("x" %in% 
names(environment(mod_bundle$situate))) + + # pass silly dots + expect_error(bundle(mod_fit, boop = "bop"), class = "rlib_error_dots") + + # compare predictions + expect_equal(mod_preds, mod_unbundled_preds) + expect_equal(mod_preds, mod_butchered_unbundled_preds) + + # verify metadata is preserved + mod_unbundled <- unbundle(mod_bundle) + expect_identical(mod_unbundled$tree_count, mod_fit$tree_count) + expect_identical(mod_unbundled$learning_rate, mod_fit$learning_rate) + expect_identical(mod_unbundled$feature_count, mod_fit$feature_count) +}) + +test_that("bundling + unbundling bonsai catboost model_fit", { + skip_if_not_installed("catboost") + skip_if_not_installed("bonsai") + skip_if_not_installed("parsnip") + + # define a function to fit a model ------------------------------------------- + fit_model <- function() { + library(parsnip) + library(bonsai) + + model <- boost_tree(trees = 10) %>% + set_engine("catboost", verbose = 0, random_seed = 1) %>% + set_mode("classification") %>% + fit(Species ~ ., data = iris) + + model + } + + # pass fit fn to a new session, fit, bundle, and also get preds before bundle + result <- + callr::r( + function(fit_model, test_data) { + library(parsnip) + library(bonsai) + + mod <- fit_model() + + # Get predictions before bundling + preds_before <- predict(mod, test_data) + + list( + bundle = bundle::bundle(mod), + preds_before = preds_before + ) + }, + args = list(fit_model = fit_model, test_data = iris) + ) + + mod_bundle <- result$bundle + mod_preds <- result$preds_before + + # pass the bundle to a new session, unbundle it, return predictions ---------- + mod_unbundled_preds <- + callr::r( + function(mod_bundle, test_data) { + library(parsnip) + library(bonsai) + + mod_unbundled <- bundle::unbundle(mod_bundle) + + predict(mod_unbundled, test_data) + }, + args = list( + mod_bundle = mod_bundle, + test_data = iris + ) + ) + + # run expectations ----------------------------------------------------------- + # check classes + 
expect_s3_class(mod_bundle, "bundled_model_fit") + + # compare predictions - the unbundled model should give same preds as original + expect_equal(mod_preds, mod_unbundled_preds) +})