diff --git a/NAMESPACE b/NAMESPACE index 5cd89a5..9ba4edd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -86,7 +86,6 @@ importFrom(ggplot2,theme) importFrom(ggplot2,xlab) importFrom(ggplot2,ylim) importFrom(glue,glue) -importFrom(graphics,barplot) importFrom(hardhat,tune) importFrom(parsnip,augment) importFrom(parsnip,boost_tree) @@ -120,7 +119,6 @@ importFrom(rsample,validation_set) importFrom(rsample,vfold_cv) importFrom(stats,coef) importFrom(stats,fisher.test) -importFrom(stats,reformulate) importFrom(stringr,str_remove) importFrom(stringr,str_split) importFrom(tibble,add_column) diff --git a/R/arg_check_ml.R b/R/arg_check_ml.R index 41f01c8..5efedee 100644 --- a/R/arg_check_ml.R +++ b/R/arg_check_ml.R @@ -251,7 +251,7 @@ NULL #' @noRd #' @keywords internal #' @param parsnip_mod A parsnip model object, such as the output of -#' `buildLRModel()` (random forest and boosted tree support planned) +#' `buildLRModel()` #' .checkArgParsnipMod <- function(parsnip_mod) { if (class(parsnip_mod)[2] != "model_spec") { diff --git a/R/core_ml.R b/R/core_ml.R index db874db..f325ea1 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -73,7 +73,8 @@ NULL #' @return An `rsplit` object #' @export splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280) { - .checkArgTibble(ml_input_tibble, ml = TRUE); .checkArgSplit(split) + .checkArgTibble(ml_input_tibble, ml = TRUE) + .checkArgSplit(split) .checkArgSeed(seed) set.seed(seed) @@ -85,7 +86,7 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 # If in CV mode: # Still retain a stratified testing holdout purely for final reporting metrics; # CV is only performed on the training portion. 
- prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test + prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test data_split <- rsample::initial_split( ml_input_tibble, prop = prop_train_for_holdout, @@ -115,7 +116,8 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 #' @return A `recipe` object #' @export buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { - .checkArgTibble(train_data, ml = TRUE); .checkArgUsePCA(use_pca) + .checkArgTibble(train_data, ml = TRUE) + .checkArgUsePCA(use_pca) .checkArgPCAThreshold(pca_threshold) target_var <- .getTargetVarName(train_data) |> as.character() @@ -124,8 +126,10 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { nm <- names(train_data) id_cols <- setdiff(nm[grepl("^genome", nm)], target_var) - rec <- recipes::recipe(formula = stats::reformulate(".", response = target_var), - data = train_data) + rec <- recipes::recipe( + formula = stats::reformulate(".", response = target_var), + data = train_data + ) # Only update roles if we actually have ID columns to mark as metadata if (length(id_cols) > 0) { @@ -146,7 +150,6 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { } - #' buildLRModel() #' #' Builds a logistic regression model. 
@@ -158,13 +161,17 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { buildLRModel <- function(multi_class = FALSE) { .checkArgMultiClass(multi_class) - if(!multi_class) { - lr_mod <- parsnip::logistic_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + if (!multi_class) { + lr_mod <- parsnip::logistic_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") - } else if(multi_class) { - lr_mod <- parsnip::multinom_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + } else if (multi_class) { + lr_mod <- parsnip::multinom_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") } @@ -176,14 +183,16 @@ buildLRModel <- function(multi_class = FALSE) { #' Builds a `tidymodels` workflow based on an input model and recipe. #' #' @param parsnip_mod A `parsnip` model object, such as the output of -#' `buildLRModel()` (random forest and boosted tree support planned) +#' `buildLRModel()` #' @param recipe A recipe, such as the output of `buildRecipe()` #' @return A `workflow` object #' @export buildWflow <- function(parsnip_mod, recipe) { - .checkArgParsnipMod(parsnip_mod); .checkArgRecipe(recipe) + .checkArgParsnipMod(parsnip_mod) + .checkArgRecipe(recipe) - wflow <- workflows::workflow() |> workflows::add_model(parsnip_mod) |> + wflow <- workflows::workflow() |> + workflows::add_model(parsnip_mod) |> workflows::add_recipe(recipe) return(wflow) @@ -195,29 +204,39 @@ buildWflow <- function(parsnip_mod, recipe) { #' #' @param model [chr] Currently, logistic regression ("LR") is supported. #' @param penalty_vec [num] A vector containing `penalty` (regularization -#' strength) values to try (for logistic regression). Recommended range: -#' 10^-4 to 10^4. +#' strength) values to try (for logistic regression). It is recommended to +#' choose values within a range of 10^-4 to 10^4. 
#' @param mix_vec [num] A vector containing `mixture` values to try for logistic #' regression. 0 corresponds to L2 regularization; 1 corresponds to L1; #' intermediate values (0, 1) correspond to elastic net. -#' @return A logistic regression tuning grid as a tibble +#' @param n_feat [num] Number of features in pangenome. Used to +#' calculate `mtry` values for a subsequent grid search (for random forest or +#' boosted tree). Output of `getNumFeat()`. +#' @param min_n_vec [num] A vector containing `min_n` values (the number of data +#' points in a node required for the node to be split) to try for random forest +#' or boosted tree. It is recommended to choose values within a range of 1 to 100. +#' @param tree_vec [num] A vector containing values to try for the number of +#' `trees` in random forest or boosted tree. It is recommended to choose values +#' within a range of 100 to 1000. +#' @return A logistic regression, random forest, or boosted tree tuning grid as +#' a tibble #' @export buildTuningGrid <- function( - model = "LR", - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5 + model = "LR", + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5 ) { .checkArgModel(model) - + if (model == "LR") { .checkArgPenaltyVec(penalty_vec) .checkArgMixVec(mix_vec) - + penalty <- rep(penalty_vec, each = length(mix_vec)) mixture <- rep(mix_vec, length(penalty_vec)) grid <- tibble::tibble(penalty, mixture) } - + return(grid) } @@ -237,13 +256,14 @@ buildTuningGrid <- function( #' @export tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), n_fold = 5) { - .checkArgTibble(grid); .checkArgWflow(wflow) + .checkArgTibble(grid) + .checkArgWflow(wflow) .checkArgDataSplit(data_split) split_class <- class(data_split)[1] # Always do CV on the training portion of the split - train_df <- rsample::training(data_split) + train_df <- rsample::training(data_split) target_var <- .getTargetVarName(train_df) if (identical(split_class, 
"initial_split")) { @@ -259,9 +279,9 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), tune_res <- tune::tune_grid( wflow, resamples = resamples, - grid = grid, - control = tune::control_grid(save_pred = TRUE), - metrics = yardstick::metric_set( + grid = grid, + control = tune::control_grid(save_pred = TRUE), + metrics = yardstick::metric_set( yardstick::f_meas, yardstick::pr_auc, yardstick::spec, @@ -286,7 +306,8 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), #' @return Best model workflow #' @export selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { - .checkArgTuneRes(tune_res); .checkArgWflow(wflow) + .checkArgTuneRes(tune_res) + .checkArgWflow(wflow) .checkArgSelectBestMetric(select_best_metric) best_mod <- tune::select_best(tune_res, metric = select_best_metric) @@ -306,7 +327,8 @@ selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { #' @return Best model fit #' @export fitBestModel <- function(final_mod, train_data) { - .checkArgWflow(final_mod); .checkArgTibble(train_data, ml = TRUE) + .checkArgWflow(final_mod) + .checkArgTibble(train_data, ml = TRUE) fit <- final_mod |> parsnip::fit(data = train_data) @@ -324,8 +346,7 @@ fitBestModel <- function(final_mod, train_data) { model <- class(fit$fit$actions$model$spec)[1] - if(model %in% c("logistic_reg", "multinom_reg")) { - + if (model %in% c("logistic_reg", "multinom_reg")) { penalty <- fit$fit$fit$spec$args$penalty mixture <- tryCatch( @@ -334,7 +355,6 @@ fitBestModel <- function(final_mod, train_data) { ) tibble::tibble(penalty = penalty, mixture = mixture) - } else { stop("The `fit` object provided must correspond to 'logistic_reg' or 'multinom_reg'.") } @@ -353,7 +373,8 @@ fitBestModel <- function(final_mod, train_data) { #' labels #' @export predictML <- function(fit, test_data) { - .checkArgWflow(fit); .checkArgTibble(test_data, ml = TRUE) + .checkArgWflow(fit) + .checkArgTibble(test_data, ml = 
TRUE) test_data_plus_predictions <- parsnip::augment(fit, test_data) @@ -396,7 +417,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) { mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() + dplyr::select(.estimate) |> + as.numeric() nmcc <- (mcc + 1) / 2 @@ -413,15 +435,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateF1 <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } f1 <- test_data_plus_predictions |> - yardstick::f_meas(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::f_meas( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(f1) @@ -437,16 +465,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateAUPRC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } auprc <- test_data_plus_predictions |> yardstick::pr_auc( - truth = genome_drug.resistant_phenotype, .pred_Resistant) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, .pred_Resistant + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(auprc) } @@ -461,26 +494,33 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateLog2APOP <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } auprc <- .calculateAUPRC(test_data_plus_predictions) prior <- sum( - test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant") / + test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant" + ) / nrow(test_data_plus_predictions) - if(prior > 0.3 && prior < 0.7) { - warning(paste("Classes are roughly balanced.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) - } else if(prior >= 0.7) { - warning(paste("Classes are imbalanced toward the resistant phenotype.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) + if (prior > 0.3 && prior < 0.7) { + warning(paste( + "Classes are roughly balanced.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) + } else if (prior >= 0.7) { + warning(paste( + "Classes are imbalanced toward the resistant phenotype.", + "Calculation of log2(AUPRC/prior) may be inappropriate." 
+ )) } - log2_apop <- log2(auprc/prior) |> round(2) + log2_apop <- log2(auprc / prior) |> round(2) return(log2_apop) } @@ -495,16 +535,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateBalAcc <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } bal_acc <- test_data_plus_predictions |> yardstick::bal_accuracy( - truth = genome_drug.resistant_phenotype, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(bal_acc) } @@ -519,15 +564,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSensitivity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } sens <- test_data_plus_predictions |> - yardstick::sens(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::sens( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(sens) @@ -543,15 +594,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSpecificity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } spec <- test_data_plus_predictions |> - yardstick::spec(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::spec( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(spec) @@ -598,30 +655,36 @@ calculateEvalMets <- function(test_data_plus_predictions) { #' `Importance`, and a column for `Sign` (or, for multi-class, a tibble with #' per-class columns of importance scores for each `Variable`) #' @export -extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), - n_top_feats = NA) { +extractTopFeats <- function( + fit, prop_vi_top_feats = c(0, 1), + n_top_feats = NA +) { .checkArgWflow(fit) - if(!is.na(n_top_feats)) {prop_vi_top_feats <- NA} + if (!is.na(n_top_feats)) { + prop_vi_top_feats <- NA + } # Arg checking for every permutation of `prop_vi_top_feats` and `n_top_feats` - if(is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { + if (is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { 
.checkArgPropVITopFeats(prop_vi_top_feats) - } else if(any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { .checkArgNTopFeats(n_top_feats) - } else if(any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { stop("Set either `n_top_feats` or `prop_vi_top_feats` to `NA` but not both.") - } else if(any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { stop("Please specify either `n_top_feats` or `prop_vi_top_feats`.") } - feats_arranged <- fit |> workflowsets::extract_fit_parsnip() |> vip::vi() |> + feats_arranged <- fit |> + workflowsets::extract_fit_parsnip() |> + vip::vi() |> dplyr::arrange(dplyr::desc(Importance)) - if(!is.na(n_top_feats)) { + if (!is.na(n_top_feats)) { top_feats_and_VIs <- feats_arranged |> dplyr::slice(1:n_top_feats) - } else if(any(!is.na(prop_vi_top_feats))) { + } else if (any(!is.na(prop_vi_top_feats))) { cum_vi_lower <- prop_vi_top_feats[1] * sum(feats_arranged$Importance) cum_vi_upper <- prop_vi_top_feats[2] * sum(feats_arranged$Importance) @@ -638,9 +701,11 @@ extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), # Take a different approach if using multi-class (the previous code would give # a less meaningful result). - if(class(fit$fit$actions$model$spec)[1] == "multinom_reg") { - warning(paste("Extracting top features from a multi-class model.", - "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply.")) + if (class(fit$fit$actions$model$spec)[1] == "multinom_reg") { + warning(paste( + "Extracting top features from a multi-class model.", + "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply." 
+ )) fit_penalty <- .getFitHps(fit)["penalty"] |> as.numeric() glmnet_fit <- parsnip::extract_fit_engine(fit) diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index bb1dc68..b19beef 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -156,7 +156,6 @@ skipImbalancedMatrix <- function(genome_ids, split, stratify_by = NULL, verbosity = c("minimal", "debug")) { - verbosity <- match.arg(verbosity) log <- .make_logger(verbosity) @@ -197,8 +196,10 @@ skipImbalancedMatrix <- function(genome_ids, if (!dir.exists(matrix_path)) dir.create(matrix_path, recursive = TRUE) log("info", paste0("Matrix output directory: ", matrix_path)) - log("debug", paste0("Stratification: ", - ifelse(is.null(stratify_column), "None", stratify_column))) + log("debug", paste0( + "Stratification: ", + ifelse(is.null(stratify_column), "None", stratify_column) + )) # Feature and matrix types feature_types <- list( @@ -220,9 +221,11 @@ skipImbalancedMatrix <- function(genome_ids, # Safe DBI-quoting quote_condition <- function(group_cols, group_values, con) { - ids <- vapply(group_cols, - function(col) DBI::dbQuoteIdentifier(con, col), - character(1)) + ids <- vapply( + group_cols, + function(col) DBI::dbQuoteIdentifier(con, col), + character(1) + ) vals <- vapply( group_cols, function(col) { @@ -256,7 +259,6 @@ skipImbalancedMatrix <- function(genome_ids, log("debug", paste0("Found ", nrow(all_groups), " groups for type: ", group_type)) for (i in seq_len(nrow(all_groups))) { - # New connection for this group con <- DBI::dbConnect(duckdb::duckdb(), parquet_duckdb_path) @@ -268,13 +270,14 @@ skipImbalancedMatrix <- function(genome_ids, condition_string <- quote_condition(group_cols, group_values, con) # Strat filter - strat_filter <- if (!is.null(stratify_column)) + strat_filter <- if (!is.null(stratify_column)) { sprintf("AND \"%s\" IS NOT NULL AND \"%s\" != ''", stratify_column, stratify_column) - else "" + } else { + "" + } # Genome selection logic if (group_type 
%in% c("drug_class", "drug_class_year", "drug_class_country")) { - genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( SELECT \"genome_drug.genome_id\" AS genome_id, @@ -290,7 +293,6 @@ skipImbalancedMatrix <- function(genome_ids, FROM class_phenotypes WHERE any_resistant = 1 OR all_susceptible = 1 ", condition_string))[[1]] - } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" SELECT DISTINCT \"genome_drug.genome_id\" @@ -310,19 +312,24 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string)) phenotype_summary <- paste( - apply(phenotype_counts_all, 1, - function(row) paste0(row["phenotype"], "=", row["count"])), + apply( + phenotype_counts_all, 1, + function(row) paste0(row["phenotype"], "=", row["count"]) + ), collapse = "; " ) # Apply skip logic if (skipImbalancedMatrix(genome_ids, phenotype_counts_all, n_fold, split, - verbosity = verbosity)) { - + verbosity = verbosity + )) { readr::write_lines( - sprintf("%s\tToo few samples for CV/split\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few samples for CV/split\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -331,9 +338,12 @@ skipImbalancedMatrix <- function(genome_ids, if (length(genome_ids) < 40) { readr::write_lines( - sprintf("%s\tToo few observations\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few observations\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -351,9 +361,12 @@ skipImbalancedMatrix <- function(genome_ids, if (nrow(phen2) < 2) { readr::write_lines( - sprintf("%s\tOnly one phenotype class\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tOnly one phenotype class\t%d\t%s", + 
group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -363,13 +376,14 @@ skipImbalancedMatrix <- function(genome_ids, # Create selected_genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genome_ids), append = TRUE) + data.frame(genome_id = genome_ids), + append = TRUE + ) # Feature and matrix generation steps for (ftype in names(feature_types)) { - fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col # binary view DBI::dbExecute(con, sprintf(" @@ -389,13 +403,14 @@ skipImbalancedMatrix <- function(genome_ids, } for (mtype in names(matrix_types)) { - binary_only <- matrix_types[[mtype]]$binary_only if (ftype == "struct" && !binary_only) next - mview <- sprintf("%s_%s", ftype, - ifelse(grepl("binary", mtype), "binary", "counts")) - value_col <- matrix_types[[mtype]]$value_col + mview <- sprintf( + "%s_%s", ftype, + ifelse(grepl("binary", mtype), "binary", "counts") + ) + value_col <- matrix_types[[mtype]]$value_col filter_clause <- matrix_types[[mtype]]$filter # select features with non-zero variance @@ -409,29 +424,38 @@ skipImbalancedMatrix <- function(genome_ids, keep_features <- DBI::dbGetQuery(con, keep_query)[["feature_id"]] if (length(keep_features) == 0) { - log("info", paste0("All features filtered for ", - ftype, " - ", mtype, " - ", group_label)) + log("info", paste0( + "All features filtered for ", + ftype, " - ", mtype, " - ", group_label + )) next } - DBI::dbExecute(con, - "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") + DBI::dbExecute( + con, + "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)" + ) DBI::dbWriteTable(con, - "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + "keep_features", + data.frame(feature_id = keep_features), + append = 
TRUE + ) mtype_label <- matrix_types[[mtype]]$label - long_out_path <- file.path(matrix_path, - sprintf("%s_%s_%s_%s_%s_sparse.parquet", - bug, group_type, group_label, ftype, mtype_label)) + long_out_path <- file.path( + matrix_path, + sprintf( + "%s_%s_%s_%s_%s_sparse.parquet", + bug, group_type, group_label, ftype, mtype_label + ) + ) long_out_path_sql <- gsub("\\\\", "/", long_out_path) # phenotype case phenotype_case <- if (group_type %in% - c("drug_class", "drug_class_year", "drug_class_country")) { + c("drug_class", "drug_class_year", "drug_class_country")) { " CASE WHEN MAX(CASE WHEN f.\"genome_drug.resistant_phenotype\"='Resistant' @@ -451,13 +475,20 @@ skipImbalancedMatrix <- function(genome_ids, " } - strat_col_select <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_select <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - strat_col_group <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_group <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( SELECT f.\"genome_drug.genome_id\" AS genome_id, @@ -478,18 +509,21 @@ skipImbalancedMatrix <- function(genome_ids, TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') ", - fid, value_col, phenotype_case, strat_col_select, - mview, fid, condition_string, - strat_filter, fid, strat_col_group, fid, - long_out_path_sql) + fid, value_col, phenotype_case, strat_col_select, + mview, fid, condition_string, + strat_filter, fid, strat_col_group, fid, + long_out_path_sql + ) ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) # On copy failure, log + continue without stopping entire pipeline if (inherits(ok, "try-error")) { readr::write_lines( - sprintf("%s\tCOPY_failed\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), + sprintf( + "%s\tCOPY_failed\t%d\t%s", + group_label, 
length(genome_ids), phenotype_summary + ), log_path, append = TRUE ) @@ -530,7 +564,7 @@ skipImbalancedMatrix <- function(genome_ids, # Normalize paths to forward slashes for consistency matrix_path <- gsub("\\\\", "/", file.path(path, paste0("matrix_", stratify_by))) - LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) + LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) if (!dir.exists(matrix_path)) { log("info", paste0("The matrix directory ", matrix_path, " does not exist.")) @@ -626,9 +660,11 @@ skipImbalancedMatrix <- function(genome_ids, out_file <- gsub("\\\\", "/", file.path( LOO_path, - paste0(sub_prefix, "_", stratify_by, "_", - drug_class, "_leaveout_", leave_one_out, "_", - sub_feature, "_sparse.parquet") + paste0( + sub_prefix, "_", stratify_by, "_", + drug_class, "_leaveout_", leave_one_out, "_", + sub_feature, "_sparse.parquet" + ) )) arrow::write_parquet(combined, out_file) created <<- c(created, out_file) @@ -702,7 +738,7 @@ skipImbalancedMatrix <- function(genome_ids, # Build one matrix per feature type and matrix type for (ftype in names(feature_types)) { fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col for (mtype in names(matrix_types)) { binary_only <- matrix_types[[mtype]]$binary_only @@ -722,8 +758,9 @@ skipImbalancedMatrix <- function(genome_ids, # Selected genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genomes_to_keep), - append = TRUE) + data.frame(genome_id = genomes_to_keep), + append = TRUE + ) # Binary view DBI::dbExecute(con, sprintf(" @@ -763,13 +800,15 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") DBI::dbWriteTable(con, "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + 
data.frame(feature_id = keep_features), + append = TRUE + ) + - - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( - SELECT + SELECT f.\"genome_drug.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, @@ -779,26 +818,26 @@ skipImbalancedMatrix <- function(genome_ids, JOIN keep_features kf ON %s = kf.feature_id JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" WHERE resistant_classes <> 'Intermediate' - GROUP BY - f.\"genome_drug.genome_id\", - %s, + GROUP BY + f.\"genome_drug.genome_id\", + %s, resistant_classes - ORDER BY - f.\"genome_drug.genome_id\", + ORDER BY + f.\"genome_drug.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') - ", - fid, # %s -> feature_id expression column name - value_col, # %s -> value column to CAST - mview, # %s -> source view (binary or counts) - fid, # %s -> join to keep_features - fid, # %s -> group by feature id - fid, # %s -> order by feature id - out_file_sql # %s -> destination parquet file + ", + fid, # %s -> feature_id expression column name + value_col, # %s -> value column to CAST + mview, # %s -> source view (binary or counts) + fid, # %s -> join to keep_features + fid, # %s -> group by feature id + fid, # %s -> order by feature id + out_file_sql # %s -> destination parquet file ) - + ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) if (inherits(ok, "try-error")) { log("info", paste0("COPY failed for MDR matrix: ", out_file)) diff --git a/R/globals.R b/R/globals.R index 7652a41..6d2b741 100644 --- a/R/globals.R +++ b/R/globals.R @@ -2,7 +2,6 @@ # (non-standard evaluation) variables used with dplyr/tidyr/ggplot2 utils::globalVariables(c( - # Prediction columns from tidymodels ".estimate", ".pred_Resistant", @@ -42,7 +41,6 @@ utils::globalVariables(c( "pair_id", "parts", "phenotype", - "precision", "prefix", "prefix_key", diff --git a/R/merge_results.R b/R/merge_results.R new file mode 100644 index 0000000..90ce226 --- /dev/null +++ b/R/merge_results.R @@ -0,0 
+1,644 @@ +## Consolidate results into Parquet outputs +#' Merge all *_performance.tsv into one table (plus metadata) and write Parquet inside results path +#' +#' - Uses createMLResultDir() to find the ML_performance directory under `path` +#' - Parses filenames using the same semantics as createMLinputList() +#' - Binds rows from all TSVs, adds parsed columns +#' - Writes a single Parquet file **inside** the ML_performance directory +#' +#' @param path Root results path (the same 'path' you pass to createMLResultDir) +#' @param stratify_by NULL | "year" | "country" +#' @param LOO logical; default FALSE +#' @param MDR logical; default FALSE +#' @param cross_test logical; default FALSE +#' @param out_parquet optional filename (no directories). If NULL, defaults to "all_performance.parquet". +#' If a path is given, only its basename is used; it is written in ML_performance/. +#' @param compression parquet compression ("zstd" or "snappy"); default "zstd" +#' @param verbose logical; print progress messages +#' @return A tibble with all performance rows + parsed metadata columns +buildPerfPq <- function( + path, + stratify_by = NULL, + LOO = FALSE, + MDR = FALSE, + cross_test = FALSE, + out_parquet = NULL, # only filename; will be written under ML_performance/ + compression = "zstd", + verbose = TRUE +) { + # ----------------------- + # Validate inputs + # ----------------------- + if (!is.character(path) || length(path) != 1 || is.na(path) || nchar(path) == 0) { + stop("`path` must be a non-empty character scalar.") + } + path <- normalizePath(path) + + if (!is.null(stratify_by) && !stratify_by %in% c("year", "country")) { + stop("`stratify_by` must be NULL, 'year', or 'country'.") + } + if (isTRUE(LOO) && is.null(stratify_by)) { + stop("With LOO=TRUE, stratify_by must be 'year' or 'country'.") + } + if (isTRUE(MDR) && (!is.null(stratify_by) || isTRUE(LOO) || isTRUE(cross_test))) { + stop("MDR can only run when stratify_by = NULL, LOO = FALSE, cross_test = FALSE.") + 
} + + # ----------------------- + # Resolve directories from your function (ensures they exist) + # ----------------------- + paths <- createMLResultDir(path, + stratify_by = stratify_by, LOO = LOO, + cross_test = cross_test, MDR = MDR + ) + perf_dir <- paths$ML_performance + + # ----------------------- + # Locate all performance TSVs + # ----------------------- + perf_files <- list.files(perf_dir, pattern = "performance\\.tsv$", full.names = TRUE, recursive = TRUE) + if (length(perf_files) == 0) { + if (verbose) message("No *_performance.tsv files found under: ", perf_dir) + return(tibble::tibble()) + } + + # ----------------------- + # Helpers + # ----------------------- + `%||%` <- function(a, b) if (!is.null(a)) a else b + .NA_chr <- function() NA_character_ + + .find_drug_label_value <- function(tokens) { + # Finds first "drug" and determines if it's "drug" or "drug_class" + idx <- which(tokens == "drug") + if (length(idx) == 0) { + return(list(drug_label = .NA_chr(), drug_value = .NA_chr(), label_end = NA_integer_)) + } + i <- idx[1] + if (i < length(tokens) && identical(tokens[i + 1], "class")) { + list( + drug_label = "drug_class", + drug_value = if (i + 2 <= length(tokens)) tokens[i + 2] else .NA_chr(), + label_end = i + 1 + ) + } else { + list( + drug_label = "drug", + drug_value = if (i + 1 <= length(tokens)) tokens[i + 1] else .NA_chr(), + label_end = i + ) + } + } + + .parse_base <- function(base_no_suffix) { + xs <- strsplit(base_no_suffix, "_", fixed = TRUE)[[1]] + n <- length(xs) + + # Initialize + species <- .NA_chr() + mdr_tag <- .NA_chr() + phenotype <- .NA_chr() + drug_label <- .NA_chr() + drug_or_class <- .NA_chr() + strat_label <- .NA_chr() + strat_value <- .NA_chr() + strat_value_test <- .NA_chr() + leaveout <- FALSE + is_cross <- FALSE + ref_drug <- .NA_chr() + test_drug <- .NA_chr() + prefix_key <- .NA_chr() + feature <- .NA_chr() + feature_type <- .NA_chr() + feature_subtype <- .NA_chr() + + # Require at least 2 tokens for feature + if (n 
>= 2) { + feature <- paste(xs[(n - 1):n], collapse = "_") + feature_type <- xs[n - 1] + feature_subtype <- xs[n] + } + core <- if ((n - 2) >= 1) xs[1:(n - 2)] else character(0) + core_str <- paste(core, collapse = "_") + + # MDR output prefixes are "MDR__" + if (length(core) > 0 && identical(core[1], "MDR")) { + mdr_tag <- "MDR" + if (length(core) >= 2) phenotype <- paste(core[-1], collapse = "_") + prefix_key <- "MDR" + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # Cross-test variants + if (grepl("_tested_on_", core_str, fixed = TRUE)) { + is_cross <- TRUE + # 1) LOO cross-test: ... __leaveout_tested_on__ + if (grepl("_leaveout_tested_on_", core_str, fixed = TRUE)) { + leaveout <- TRUE + di <- .find_drug_label_value(core) + drug_label <- di$drug_label + drug_or_class <- di$drug_value + label_end <- di$label_end + if (!is.na(label_end)) { + prefix_key <- paste(core[1:label_end], collapse = "_") + i_val <- label_end + 1 + # expect: value, leaveout, tested, on, strat_value + if ((i_val + 4) <= length(core) && + core[i_val + 1] == "leaveout" && core[i_val + 2] == "tested" && core[i_val + 3] == "on") { + strat_label <- stratify_by %||% .NA_chr() + strat_value <- core[i_val + 4] + } + } + species <- if (length(core) >= 1) core[1] else .NA_chr() + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = 
prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # 2) Cross by strat group: ... __cross__tested_on__ + if (grepl("_cross_", core_str, fixed = TRUE)) { + di <- .find_drug_label_value(core) + drug_label <- di$drug_label + drug_or_class <- di$drug_value + label_end <- di$label_end + if (!is.na(label_end)) { + prefix_key <- paste(core[1:label_end], collapse = "_") + i_val <- label_end + 1 + # expect: value, cross, strat_value, tested, on, strat_value_test + if ((i_val + 5) <= length(core) && + core[i_val + 1] == "cross" && core[i_val + 3] == "tested" && core[i_val + 4] == "on") { + strat_label <- stratify_by %||% .NA_chr() + strat_value <- core[i_val + 2] + strat_value_test <- core[i_val + 5] + } + } + species <- if (length(core) >= 1) core[1] else .NA_chr() + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # 3) Cross-test (drug vs drug): ... 
__tested_on__ + di <- .find_drug_label_value(core) + label_end <- di$label_end + if (!is.na(label_end)) { + prefix_key <- paste(core[1:label_end], collapse = "_") + if ((label_end + 4) <= length(core) && + core[label_end + 2] == "tested" && core[label_end + 3] == "on") { + ref_drug <- core[label_end + 1] + test_drug <- core[label_end + 4] + drug_label <- if (endsWith(prefix_key, "drug_class")) "drug_class" else "drug" + drug_or_class <- ref_drug + } + } + species <- if (length(core) >= 1) core[1] else .NA_chr() + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # Non-cross variants (may be stratified) + species <- if (length(core) >= 1) core[1] else .NA_chr() + di <- .find_drug_label_value(core) + drug_label <- di$drug_label + label_end <- di$label_end + + if (!is.na(label_end)) { + # -------- STRATIFIED CASE -------- + # pattern: species drug[_class] strat_label drug_value strat_value + if (label_end + 3 <= length(core) && + core[label_end + 1] %in% c("year", "country")) { + strat_label <- core[label_end + 1] + drug_or_class <- core[label_end + 2] # <-- fixed: this is the FLQ/MAC/etc + strat_value <- core[label_end + 3] # <-- fixed: this is "2015-2019" + prefix_key <- paste(core[1:label_end], collapse = "_") + + # -------- UNSTRATIFIED CASE -------- + } else { + drug_or_class <- di$drug_value + prefix_key <- paste(core[1:label_end], collapse = "_") + } + } + + list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + 
leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + ) + } + + # ----------------------- + # Read, parse and bind + # ----------------------- + out <- purrr::map_dfr(perf_files, function(f) { + base <- basename(f) + base_no_suffix <- sub("_performance\\.tsv$", "", base) + if (identical(base_no_suffix, base)) { + base_no_suffix <- sub("\\.tsv$", "", base) # fallback + } + + meta <- .parse_base(base_no_suffix) + + df <- tryCatch( + readr::read_tsv(f, show_col_types = FALSE, progress = FALSE), + error = function(e) { + if (verbose) message("Failed to read TSV: ", f, " (", conditionMessage(e), ") — using metadata only.") + tibble::tibble() + } + ) + + md_cols <- tibble::tibble( + output_prefix = base_no_suffix, + species = meta$species, + mdr_tag = meta$mdr_tag, + phenotype = meta$phenotype, + drug_label = meta$drug_label, + drug_or_class = meta$drug_or_class, + strat_label = meta$strat_label, + strat_value = meta$strat_value, + strat_value_test = meta$strat_value_test, + leaveout = meta$leaveout, + cross_test = meta$is_cross, + ref_drug = meta$ref_drug, + test_drug = meta$test_drug, + prefix_key = meta$prefix_key, + feature = meta$feature, + feature_type = meta$feature_type, + feature_subtype = meta$feature_subtype + ) + + if (nrow(df) == 0) { + md_cols + } else { + dplyr::bind_cols(md_cols[rep(1, nrow(df)), ], df) + } + }) + + # ----------------------- + # Compute the output Parquet path (ALWAYS inside perf_dir) + # ----------------------- + # If user gives a name/path, we keep only the basename and write under perf_dir. 
+ parquet_name <- (out_parquet %||% "all_performance.parquet") + parquet_name <- basename(parquet_name) + out_path <- file.path(perf_dir, parquet_name) + + # ----------------------- + # Write Parquet + # ----------------------- + suppressPackageStartupMessages(library(arrow)) + arrow::write_parquet(out, out_path, compression = compression) + if (verbose) message("Wrote merged Parquet: ", out_path, " [", nrow(out), " rows]") + + out +} + + +#' Merge all *_top_features.tsv into one table + metadata, write Parquet inside results path +#' +#' - Uses createMLResultDir() to find the ML_top_features directory under `path` +#' - Parses filenames to derive metadata (aligned with createMLinputList() semantics) +#' - Binds rows from all top-features TSVs (keeps all original columns) +#' - Writes a single Parquet file **inside** the ML_top_features directory +#' +#' @param path Root results path (same `path` used for createMLResultDir) +#' @param stratify_by NULL | "year" | "country" +#' @param LOO logical; default FALSE +#' @param MDR logical; default FALSE +#' @param cross_test logical; default FALSE +#' @param out_parquet optional filename (no directories). If NULL, defaults to "all_top_features.parquet". +#' If a path is given, only its basename is used; the file is written in ML_top_features/. 
+#' @param compression parquet compression ("zstd" or "snappy"); default "zstd" +#' @param verbose logical; print progress messages +#' @return A tibble with all top-features rows + parsed metadata columns +buildTopFeatsPq <- function( + path, + stratify_by = NULL, + LOO = FALSE, + MDR = FALSE, + cross_test = FALSE, + out_parquet = NULL, # only filename; will be written under ML_top_features/ + compression = "zstd", + verbose = TRUE +) { + # ----------------------- + # Validate inputs + # ----------------------- + if (!is.character(path) || length(path) != 1 || is.na(path) || nchar(path) == 0) { + stop("`path` must be a non-empty character scalar.") + } + path <- normalizePath(path) + + if (!is.null(stratify_by) && !stratify_by %in% c("year", "country")) { + stop("`stratify_by` must be NULL, 'year', or 'country'.") + } + if (isTRUE(LOO) && is.null(stratify_by)) { + stop("With LOO=TRUE, stratify_by must be 'year' or 'country'.") + } + if (isTRUE(MDR) && (!is.null(stratify_by) || isTRUE(LOO) || isTRUE(cross_test))) { + stop("MDR can only run when stratify_by = NULL, LOO = FALSE, cross_test = FALSE.") + } + + # ----------------------- + # Resolve directories (ensures existence) + # ----------------------- + paths <- createMLResultDir(path, + stratify_by = stratify_by, LOO = LOO, + cross_test = cross_test, MDR = MDR + ) + top_dir <- paths$ML_top_features + + # ----------------------- + # Locate all top-features TSVs + # ----------------------- + top_files <- list.files(top_dir, pattern = "top_features\\.tsv$", full.names = TRUE, recursive = TRUE) + if (length(top_files) == 0) { + if (verbose) message("No *_top_features.tsv files found under: ", top_dir) + return(tibble::tibble()) + } + + # ----------------------- + # Helpers (shared with performance parser; includes stratified fix) + # ----------------------- + `%||%` <- function(a, b) if (!is.null(a)) a else b + .NA_chr <- function() NA_character_ + + .find_drug_label_value <- function(tokens) { + # Finds first "drug" 
and determines if it's "drug" or "drug_class" + idx <- which(tokens == "drug") + if (length(idx) == 0) { + return(list(drug_label = .NA_chr(), drug_value = .NA_chr(), label_end = NA_integer_)) + } + i <- idx[1] + if (i < length(tokens) && identical(tokens[i + 1], "class")) { + list( + drug_label = "drug_class", + drug_value = if (i + 2 <= length(tokens)) tokens[i + 2] else .NA_chr(), + label_end = i + 1 + ) + } else { + list( + drug_label = "drug", + drug_value = if (i + 1 <= length(tokens)) tokens[i + 1] else .NA_chr(), + label_end = i + ) + } + } + + .parse_base <- function(base_no_suffix) { + xs <- strsplit(base_no_suffix, "_", fixed = TRUE)[[1]] + n <- length(xs) + + # Initialize + species <- .NA_chr() + mdr_tag <- .NA_chr() + phenotype <- .NA_chr() + drug_label <- .NA_chr() + drug_or_class <- .NA_chr() + strat_label <- .NA_chr() + strat_value <- .NA_chr() + strat_value_test <- .NA_chr() + leaveout <- FALSE + is_cross <- FALSE + ref_drug <- .NA_chr() + test_drug <- .NA_chr() + prefix_key <- .NA_chr() + feature <- .NA_chr() + feature_type <- .NA_chr() + feature_subtype <- .NA_chr() + + # Feature from last 2 tokens + if (n >= 2) { + feature <- paste(xs[(n - 1):n], collapse = "_") + feature_type <- xs[n - 1] + feature_subtype <- xs[n] + } + core <- if ((n - 2) >= 1) xs[1:(n - 2)] else character(0) + core_str <- paste(core, collapse = "_") + + # MDR: "MDR__" + if (length(core) > 0 && identical(core[1], "MDR")) { + mdr_tag <- "MDR" + if (length(core) >= 2) phenotype <- paste(core[-1], collapse = "_") + prefix_key <- "MDR" + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # 
Cross-test variants + if (grepl("_tested_on_", core_str, fixed = TRUE)) { + is_cross <- TRUE + + # LOO cross-test: ... __leaveout_tested_on__ + if (grepl("_leaveout_tested_on_", core_str, fixed = TRUE)) { + leaveout <- TRUE + di <- .find_drug_label_value(core) + drug_label <- di$drug_label + drug_or_class <- di$drug_value + label_end <- di$label_end + if (!is.na(label_end)) { + prefix_key <- paste(core[1:label_end], collapse = "_") + i_val <- label_end + 1 + if ((i_val + 4) <= length(core) && + core[i_val + 1] == "leaveout" && core[i_val + 2] == "tested" && core[i_val + 3] == "on") { + strat_label <- stratify_by %||% .NA_chr() + strat_value <- core[i_val + 4] + } + } + species <- if (length(core) >= 1) core[1] else .NA_chr() + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # Cross by strat group: ... 
__cross__tested_on__ + if (grepl("_cross_", core_str, fixed = TRUE)) { + di <- .find_drug_label_value(core) + drug_label <- di$drug_label + drug_or_class <- di$drug_value + label_end <- di$label_end + if (!is.na(label_end)) { + prefix_key <- paste(core[1:label_end], collapse = "_") + i_val <- label_end + 1 + if ((i_val + 5) <= length(core) && + core[i_val + 1] == "cross" && core[i_val + 3] == "tested" && core[i_val + 4] == "on") { + strat_label <- stratify_by %||% .NA_chr() + strat_value <- core[i_val + 2] + strat_value_test <- core[i_val + 5] + } + } + species <- if (length(core) >= 1) core[1] else .NA_chr() + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # Cross-test (drug vs drug): ... 
__tested_on__ + di <- .find_drug_label_value(core) + label_end <- di$label_end + if (!is.na(label_end)) { + prefix_key <- paste(core[1:label_end], collapse = "_") + if ((label_end + 4) <= length(core) && + core[label_end + 2] == "tested" && core[label_end + 3] == "on") { + ref_drug <- core[label_end + 1] + test_drug <- core[label_end + 4] + drug_label <- if (endsWith(prefix_key, "drug_class")) "drug_class" else "drug" + drug_or_class <- ref_drug + } + } + species <- if (length(core) >= 1) core[1] else .NA_chr() + return(list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + )) + } + + # Non-cross variants (un/stratified), with stratified FIX + species <- if (length(core) >= 1) core[1] else .NA_chr() + di <- .find_drug_label_value(core) + drug_label <- di$drug_label + label_end <- di$label_end + + if (!is.na(label_end)) { + # STRATIFIED pattern (fix): species drug[_class] strat_label drug_value strat_value + if (label_end + 3 <= length(core) && core[label_end + 1] %in% c("year", "country")) { + strat_label <- core[label_end + 1] + drug_or_class <- core[label_end + 2] # e.g., FLQ, MAC, CIP, etc. 
+ strat_value <- core[label_end + 3] # e.g., 2015-2019 or country name + prefix_key <- paste(core[1:label_end], collapse = "_") + } else { + # UNSTRATIFIED pattern: species drug[_class] drug_value + drug_or_class <- di$drug_value + prefix_key <- paste(core[1:label_end], collapse = "_") + } + } + + list( + species = species, mdr_tag = mdr_tag, phenotype = phenotype, + drug_label = drug_label, drug_or_class = drug_or_class, + strat_label = strat_label, strat_value = strat_value, strat_value_test = strat_value_test, + leaveout = leaveout, is_cross = is_cross, + ref_drug = ref_drug, test_drug = test_drug, + prefix_key = prefix_key, + feature = feature, feature_type = feature_type, feature_subtype = feature_subtype + ) + } + + # ----------------------- + # Read, parse and bind + # ----------------------- + out <- purrr::map_dfr(top_files, function(f) { + base <- basename(f) + # Accept either "..._top_features.tsv" (preferred) or ".tsv" fallback + base_no_suffix <- sub("_top_features\\.tsv$", "", base) + if (identical(base_no_suffix, base)) { + base_no_suffix <- sub("\\.tsv$", "", base) + } + + meta <- .parse_base(base_no_suffix) + + df <- tryCatch( + readr::read_tsv(f, show_col_types = FALSE, progress = FALSE), + error = function(e) { + if (verbose) message("Failed to read TSV: ", f, " (", conditionMessage(e), ") — using metadata only.") + tibble::tibble() + } + ) + + # Attach metadata columns + md_cols <- tibble::tibble( + output_prefix = base_no_suffix, + species = meta$species, + mdr_tag = meta$mdr_tag, + phenotype = meta$phenotype, + drug_label = meta$drug_label, + drug_or_class = meta$drug_or_class, + strat_label = meta$strat_label, + strat_value = meta$strat_value, + strat_value_test = meta$strat_value_test, + leaveout = meta$leaveout, + cross_test = meta$is_cross, + ref_drug = meta$ref_drug, + test_drug = meta$test_drug, + prefix_key = meta$prefix_key, + feature = meta$feature, + feature_type = meta$feature_type, + feature_subtype = meta$feature_subtype + ) + + 
if (nrow(df) == 0) { + md_cols + } else { + dplyr::bind_cols(md_cols[rep(1, nrow(df)), ], df) + } + }) + + # ----------------------- + # Compute the output Parquet path (ALWAYS inside top_dir) + # ----------------------- + parquet_name <- (out_parquet %||% "all_top_features.parquet") + parquet_name <- basename(parquet_name) + out_path <- file.path(top_dir, parquet_name) + + # ----------------------- + # Write Parquet + # ----------------------- + suppressPackageStartupMessages(library(arrow)) + arrow::write_parquet(out, out_path, compression = compression) + if (verbose) message("Wrote merged Parquet: ", out_path, " [", nrow(out), " rows]") + + out +} diff --git a/R/plot_ml.R b/R/plot_ml.R index 071e1a1..90121f3 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -214,7 +214,6 @@ plotFishers <- function( alpha = 0.05, label_top_n = 5 ) { - required_cols <- c("gene", "adj_p_value", "sig_after_bh") missing_cols <- setdiff(required_cols, colnames(fisher_df)) diff --git a/R/prep_ml.R b/R/prep_ml.R index cbdaf4f..e408721 100644 --- a/R/prep_ml.R +++ b/R/prep_ml.R @@ -111,8 +111,10 @@ loadMLInputTibble <- function(parquet_path) { if (exists(".ml_logger")) { log <- .ml_logger("minimal") - log("debug", paste0("ML tibble constructed: ", nrow(ml_input_tibble), - " genomes × ", getNumFeat(ml_input_tibble), " features")) + log("debug", paste0( + "ML tibble constructed: ", nrow(ml_input_tibble), + " genomes × ", getNumFeat(ml_input_tibble), " features" + )) } if (anyDuplicated(dplyr::pull(ml_input_tibble, genome_id)) != 0) { diff --git a/R/run_ML.R b/R/run_ML.R index eba37f8..2ed07e7 100644 --- a/R/run_ML.R +++ b/R/run_ML.R @@ -4,9 +4,11 @@ #' the ML matrices with these new split/CV values instead. 
#' @noRd .resolveSplitParams <- function(parquet_path, - defaults = list(split = c(0.8, 0), - seed = 5280, - n_fold = 5)) { + defaults = list( + split = c(0.8, 0), + seed = 5280, + n_fold = 5 + )) { # matrix_dir is the directory that contains the parquet files matrix_dir <- normalizePath(dirname(parquet_path)) params_json <- .readMLParameters(matrix_dir) @@ -16,8 +18,8 @@ } list( - split = if (!is.null(params_json$split)) params_json$split else defaults$split, - seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, + split = if (!is.null(params_json$split)) params_json$split else defaults$split, + seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, n_fold = if (!is.null(params_json$n_fold)) params_json$n_fold else defaults$n_fold ) } @@ -53,8 +55,9 @@ #' #' # LOO analysis stratified by year #' paths_loo <- createMLResultDir("/path/to/results", -#' stratify_by = "year", -#' LOO = TRUE) +#' stratify_by = "year", +#' LOO = TRUE +#' ) #' #' # MDR analysis #' paths_mdr <- createMLResultDir("/path/to/results", MDR = TRUE) @@ -90,16 +93,17 @@ createMLResultDir <- function(path, ) } else { # Determine prefixes (only in non-MDR mode) - full_prefix <- paste0(ifelse(isTRUE(LOO), "LOO_", ""), - ifelse(isTRUE(cross_test), "cross_test_", "")) + full_prefix <- paste0( + ifelse(isTRUE(LOO), "LOO_", ""), + ifelse(isTRUE(cross_test), "cross_test_", "") + ) half_prefix <- ifelse(isTRUE(LOO), "LOO_", "") # Determine suffix suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'country', or 'year'.") @@ -127,20 +131,20 @@ createMLResultDir <- function(path, return(paths) } - # createAllMLResultDir <- function(path) { - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = 
FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # } - # +# createAllMLResultDir <- function(path) { +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# } +# #' Create machine 
learning input list #' @@ -174,8 +178,9 @@ createMLResultDir <- function(path, #' #' # Cross-test with year stratification #' inputs_ct <- createMLinputList("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE) +#' stratify_by = "year", +#' cross_test = TRUE +#' ) #' #' # MDR analysis #' inputs_mdr <- createMLinputList("/path/to/results", MDR = TRUE) @@ -187,10 +192,10 @@ createMLinputList <- function(path, LOO = FALSE, MDR = FALSE, cross_test = FALSE) { - # Validate inputs - if (!is.character(path) || length(path) != 1 || is.na(path)) + if (!is.character(path) || length(path) != 1 || is.na(path)) { stop("`path` must be a valid file path string.") + } path <- normalizePath(path) @@ -225,21 +230,17 @@ createMLinputList <- function(path, # Multi-drug resistance models # ============================ if (MDR) { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( parts = stringr::str_split(basename(ref_file), "_"), - species = purrr::map_chr(parts, ~ .x[1]), - mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" + mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" phenotype = purrr::map_chr(parts, ~ paste(.x[3:4], collapse = "_")), # Feature is 5th + 6th tokens feature_type = purrr::map_chr(parts, ~ .x[5]), feature_subtype = purrr::map_chr(parts, ~ stringr::str_remove(.x[6], "_sparse.parquet")), - feature = purrr::map2_chr(feature_type, feature_subtype, paste, sep = "_"), - output_prefix = paste0("MDR_", phenotype, "_", feature) ) @@ -247,38 +248,43 @@ createMLinputList <- function(path, dplyr::mutate( test_file = NA_character_, matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # For all other modeling types - # 
============================ + # ============================ + # For all other modeling types + # ============================ } else { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( - parts = stringr::str_split(basename(ref_file), "_"), + parts = stringr::str_split(basename(ref_file), "_"), i_sparse = purrr::map_int(parts, ~ .get_idx(.x, "sparse.parquet")), - i_strat = purrr::map_int(parts, ~ { - if (is.null(stratify_by)) return(NA_integer_) + i_strat = purrr::map_int(parts, ~ { + if (is.null(stratify_by)) { + return(NA_integer_) + } .get_idx(.x, stratify_by) }), # Feature = last two tokens before sparse.parquet feature = purrr::map2_chr(parts, i_sparse, ~ { - i <- .y; x <- .x - if (is.na(i) || i < 3) return(NA_character_) + i <- .y + x <- .x + if (is.na(i) || i < 3) { + return(NA_character_) + } paste(x[(i - 2):(i - 1)], collapse = "_") }), # Drug or drug class extraction drug_or_class = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Stratified models if (!is.na(i)) { @@ -304,32 +310,40 @@ createMLinputList <- function(path, # Stratification value (if present) strat_value = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x - if (is.na(i)) return("") + i <- .y + x <- .x + if (is.na(i)) { + return("") + } # default position is two tokens after the strat label j <- i + 2 # if there's an intervening 'leaveout', skip over it if (j <= length(x) && identical(x[j], "leaveout")) j <- j + 1 - if (j <= length(x)) return(x[j]) - "" # no stratification + if (j <= length(x)) { + return(x[j]) + } + "" # no stratification }), # Prefix key for grouping prefix_key = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Case A: stratified -> prefix before the stratify label if (!is.na(i)) { - if (i - 1 >= 1) return(paste(x[1:(i - 1)], collapse = "_")) + if (i - 1 >= 1) { + return(paste(x[1:(i - 1)], collapse = "_")) + } return("") } # Case B: unstratified -> prefix is first two tokens - if (x[2] == 
"drug" && x[3] != "class"){ + if (x[2] == "drug" && x[3] != "class") { # Case A: Cje_drug_X return(paste(x[1:2], collapse = "_")) } - if (x[2] == "drug" && x[3] == "class"){ + if (x[2] == "drug" && x[3] == "class") { # Case A: Cje_drug_X return(paste(x[1:3], collapse = "_")) } @@ -345,18 +359,17 @@ createMLinputList <- function(path, test_file = NA_character_, output_prefix = gsub("_sparse\\.parquet$", "", basename(ref_file)), matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test modeling, no LOO - # ============================ + # ============================ + # Cross-test modeling, no LOO + # ============================ } else if (cross_test && !LOO) { - if (is.null(stratify_by)) { # Case A: stratify_by = NULL, pair across abx within same feature + prefix pairs <- parsed |> @@ -366,8 +379,10 @@ createMLinputList <- function(path, dplyr::select(test_file = ref_file, feature, prefix_key, strat_value, test_drug = drug_or_class), by = c("feature", "prefix_key", "strat_value") ) |> - dplyr::filter(ref_file != test_file, - ref_drug != test_drug) |> + dplyr::filter( + ref_file != test_file, + ref_drug != test_drug + ) |> dplyr::distinct() |> dplyr::mutate( output_prefix = paste0( @@ -380,10 +395,10 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -392,30 +407,29 @@ createMLinputList <- function(path, # Case B: stratify_by != 
NULL, pair same drug/class, prefix, feature, # but across different stratification groups pairs <- parsed |> - dplyr::select(ref_file, feature, prefix_key, strat_value, - drug_or_class) |> - + dplyr::select( + ref_file, feature, prefix_key, strat_value, + drug_or_class + ) |> # self-join ONLY on prefix_key, drug/class, feature dplyr::inner_join( parsed |> - dplyr::select(test_file = ref_file, - feature, prefix_key, strat_value_test = strat_value, - drug_or_class), + dplyr::select( + test_file = ref_file, + feature, prefix_key, strat_value_test = strat_value, + drug_or_class + ), by = c("prefix_key", "feature", "drug_or_class") ) |> - # do NOT test file against itself dplyr::filter(ref_file != test_file) |> - # enforce different stratification group dplyr::filter(strat_value != strat_value_test) |> - # remove symmetric duplicates (A,B == B,A) dplyr::rowwise() |> dplyr::mutate(pair_id = paste(sort(c(ref_file, test_file)), collapse = "||")) |> dplyr::ungroup() |> dplyr::distinct(pair_id, .keep_all = TRUE) |> - dplyr::mutate( output_prefix = paste0( prefix_key, "_", @@ -429,19 +443,18 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test + LOO modeling - # ============================ + # ============================ + # Cross-test + LOO modeling + # ============================ } else if (cross_test && LOO) { - # LOO requires special directory structure resolution test_path <- file.path(path, stringr::str_remove(basename(paths$matrix_path), "^LOO_")) test_path <- normalizePath(test_path) @@ -461,10 +474,10 @@ createMLinputList <- function(path, out <- loo_pairs |> dplyr::mutate( matrix_path = 
paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -472,9 +485,11 @@ createMLinputList <- function(path, } # If we ever get here, something wasn't covered - stop("Unhandled combination of arguments: ", - "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, - ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by) + stop( + "Unhandled combination of arguments: ", + "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, + ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by + ) } @@ -544,13 +559,15 @@ createMLinputList <- function(path, #' #' # Run with more threads and minimal output #' runMDRmodels("/path/to/results", -#' threads = 32, -#' verbose = FALSE) +#' threads = 32, +#' verbose = FALSE +#' ) #' #' # Run without saving model fits (save disk space) #' runMDRmodels("/path/to/results", -#' threads = 16, -#' return_fit = FALSE) +#' threads = 16, +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -571,12 +588,12 @@ runMDRmodels <- function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - files <- createMLinputList(path, - stratify_by = NULL, - LOO = FALSE, - cross_test = FALSE, - MDR = TRUE) + stratify_by = NULL, + LOO = FALSE, + cross_test = FALSE, + MDR = TRUE + ) if (nrow(files) == 0) { message("No MDR files found to process. 
Exiting.") @@ -594,18 +611,19 @@ runMDRmodels <- function(path, # Auto tags for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMDRmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMDRmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -619,32 +637,37 @@ runMDRmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = NA, - model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = NA, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, 
"try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -652,19 +675,25 @@ runMDRmodels <- function(path, base <- paste0(shuffle_tag, output_prefix, pca_tag) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -783,21 +812,24 @@ runMDRmodels <- function(path, #' #' # Cross-test with year stratification #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE, -#' threads = 32) +#' stratify_by = "year", +#' cross_test = TRUE, +#' threads = 32 +#' ) #' #' # LOO analysis stratified by country with cross-testing #' runMLmodels("/path/to/results", -#' stratify_by = "country", -#' LOO = TRUE, -#' cross_test = TRUE, -#' verbose = TRUE) +#' stratify_by = "country", +#' LOO = TRUE, +#' cross_test = TRUE, +#' verbose = TRUE +#' ) #' #' # Run without saving model fits (save disk space) #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' return_fit = FALSE) +#' stratify_by = "year", +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -823,19 +855,21 @@ runMLmodels <- 
function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - if (!is.null(stratify_by)) { - if (!is.character(stratify_by) || length(stratify_by) != 1L) + if (!is.character(stratify_by) || length(stratify_by) != 1L) { stop("`stratify_by` must be NULL or a single string: 'year' or 'country'.") - if (!stratify_by %in% c("year", "country")) + } + if (!stratify_by %in% c("year", "country")) { stop("`stratify_by` must be NULL, 'year', or 'country'.") + } } files <- createMLinputList(path, - stratify_by = stratify_by, - LOO = LOO, - MDR = FALSE, - cross_test = cross_test) + stratify_by = stratify_by, + LOO = LOO, + MDR = FALSE, + cross_test = cross_test + ) if (nrow(files) == 0) { message("No files found to process. Exiting.") @@ -864,8 +898,7 @@ runMLmodels <- function(path, strat_suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'year', or 'country'.") @@ -874,18 +907,19 @@ runMLmodels <- function(path, # Auto naming for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMLmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMLmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -910,32 +944,37 @@ runMLmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = test_data, - 
model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = test_data, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, "try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -943,19 +982,25 @@ runMLmodels <- function(path, base <- paste0(shuffle_tag, config_prefix, output_prefix, pca_tag, strat_suffix) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], 
paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -973,7 +1018,6 @@ runMLmodels <- function(path, } - #' Run the entire AMR ML pipeline from a parquet-backed DuckDB #' #' This function provides a complete end-to-end AMR machine learning workflow. @@ -1006,11 +1050,12 @@ runModelingPipeline <- function(parquet_duckdb_path, pca_threshold = 0.99, verbose = TRUE, use_saved_split = TRUE) { - parquet_duckdb_path <- normalizePath(parquet_duckdb_path) if (!file.exists(parquet_duckdb_path)) { - stop("Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", - "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`") + stop( + "Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", + "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`" + ) } out_root <- dirname(parquet_duckdb_path) @@ -1024,9 +1069,9 @@ runModelingPipeline <- function(parquet_duckdb_path, generateMLInputs( parquet_duckdb_path = parquet_duckdb_path, out_path = out_root, - n_fold = n_fold, - split = split, - min_n = min_n, + n_fold = n_fold, + split = split, + min_n = min_n, verbosity = if (verbose) "minimal" else "debug" ) @@ -1089,12 +1134,13 @@ runModelingPipeline <- function(parquet_duckdb_path, # All done! 
if (verbose) { message("\n=== AMR-ML Pipeline Complete ===") - message("All matrices, models, top feature lists, and performance metrics saved under:\n ", - out_root) + message( + "All matrices, models, top feature lists, and performance metrics saved under:\n ", + out_root + ) message("\nTo inspect model outputs, see directories such as:") message(" ML_performance/, ML_models/, ML_prediction/, ML_top_features/") } invisible(out_root) } - diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index 2a97c00..a6bb78d 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -46,7 +46,7 @@ NULL #' principle components account #' @param penalty_vec [num] A vector containing `penalty` (regularization #' strength) values to try (for logistic regression). It is recommended to -#' choose values `10^-4` to `10^4`. +#' choose values within a range of 10^-4 to 10^4. #' @param mix_vec [num] A vector containing `mixture` values to try for logistic #' regression. 0 corresponds to L2 regularization; 1 corresponds to L1; #' intermediate values correspond to elastic net. @@ -93,20 +93,21 @@ runMLPipeline <- function( .checkArgReturnPred(return_pred) - # Set `n_fold` to `NA` if not using cross-validation. if (split[2] != 0) { n_fold <- NA } # Confirm resolved split params - if (verbose) { - mode <- if (split[2] == 0) "cv" else "splits" - message(sprintf("ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", - mode, split[1], split[2], - ifelse(is.na(n_fold), "NA", as.character(n_fold)), - as.character(seed))) - } + if (verbose) { + mode <- if (split[2] == 0) "cv" else "splits" + message(sprintf( + "ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", + mode, split[1], split[2], + ifelse(is.na(n_fold), "NA", as.character(n_fold)), + as.character(seed) + )) + } # Create a variable indicating whether external `test_data` was provided. This # will be set to `TRUE` later if the `test_data` argument is not `NA`. 
@@ -116,10 +117,10 @@ runMLPipeline <- function( # Determine whether multi-class classification is to be performed. if (as.character(.getTargetVarName(ml_input_tibble)) == "resistant_classes") { - multi_class <- TRUE - } else { - multi_class <- FALSE - } + multi_class <- TRUE + } else { + multi_class <- FALSE + } if (model != "LR" & multi_class) { stop(paste( @@ -262,7 +263,7 @@ runMLPipeline <- function( mix_vec = mix_vec ) } - + recipe <- buildRecipe(train_data, use_pca = use_pca, pca_threshold = pca_threshold @@ -421,14 +422,16 @@ runMLPipeline <- function( all_results[["fit"]] <- fit } - if(return_pred) { - if(!multi_class){ + if (return_pred) { + if (!multi_class) { all_results[["pred"]] <- test_data_plus_predictions |> - dplyr::select(c(genome_id, .pred_class, .pred_Resistant, - .pred_Susceptible, genome_drug.resistant_phenotype)) - } - all_results[["pred"]] <- test_data_plus_predictions + dplyr::select(c( + genome_id, .pred_class, .pred_Resistant, + .pred_Susceptible, genome_drug.resistant_phenotype + )) } + all_results[["pred"]] <- test_data_plus_predictions + } return(all_results) } diff --git a/man/buildPerfPq.Rd b/man/buildPerfPq.Rd new file mode 100644 index 0000000..2506ee0 --- /dev/null +++ b/man/buildPerfPq.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/merge_results.R +\name{buildPerfPq} +\alias{buildPerfPq} +\title{Merge all *_performance.tsv into one table (plus metadata) and write Parquet inside results path} +\usage{ +buildPerfPq( + path, + stratify_by = NULL, + LOO = FALSE, + MDR = FALSE, + cross_test = FALSE, + out_parquet = NULL, + compression = "zstd", + verbose = TRUE +) +} +\arguments{ +\item{path}{Root results path (the same 'path' you pass to createMLResultDir)} + +\item{stratify_by}{NULL | "year" | "country"} + +\item{LOO}{logical; default FALSE} + +\item{MDR}{logical; default FALSE} + +\item{cross_test}{logical; default FALSE} + +\item{out_parquet}{optional filename (no directories). 
If NULL, defaults to "all_performance.parquet". +If a path is given, only its basename is used; it is written in ML_performance/.} + +\item{compression}{parquet compression ("zstd" or "snappy"); default "zstd"} + +\item{verbose}{logical; print progress messages} +} +\value{ +A tibble with all performance rows + parsed metadata columns +} +\description{ +\itemize{ +\item Uses createMLResultDir() to find the ML_performance directory under \code{path} +\item Parses filenames using the same semantics as createMLinputList() +\item Binds rows from all TSVs, adds parsed columns +\item Writes a single Parquet file \strong{inside} the ML_performance directory +} +} diff --git a/man/buildTopFeatsPq.Rd b/man/buildTopFeatsPq.Rd new file mode 100644 index 0000000..14da3b5 --- /dev/null +++ b/man/buildTopFeatsPq.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/merge_results.R +\name{buildTopFeatsPq} +\alias{buildTopFeatsPq} +\title{Merge all *_top_features.tsv into one table + metadata, write Parquet inside results path} +\usage{ +buildTopFeatsPq( + path, + stratify_by = NULL, + LOO = FALSE, + MDR = FALSE, + cross_test = FALSE, + out_parquet = NULL, + compression = "zstd", + verbose = TRUE +) +} +\arguments{ +\item{path}{Root results path (same \code{path} used for createMLResultDir)} + +\item{stratify_by}{NULL | "year" | "country"} + +\item{LOO}{logical; default FALSE} + +\item{MDR}{logical; default FALSE} + +\item{cross_test}{logical; default FALSE} + +\item{out_parquet}{optional filename (no directories). If NULL, defaults to "all_top_features.parquet". 
+If a path is given, only its basename is used; the file is written in ML_top_features/.} + +\item{compression}{parquet compression ("zstd" or "snappy"); default "zstd"} + +\item{verbose}{logical; print progress messages} +} +\value{ +A tibble with all top-features rows + parsed metadata columns +} +\description{ +\itemize{ +\item Uses createMLResultDir() to find the ML_top_features directory under \code{path} +\item Parses filenames to derive metadata (aligned with createMLinputList() semantics) +\item Binds rows from all top-features TSVs (keeps all original columns) +\item Writes a single Parquet file \strong{inside} the ML_top_features directory +} +} diff --git a/man/buildTuningGrid.Rd b/man/buildTuningGrid.Rd index 498f310..c6ee983 100644 --- a/man/buildTuningGrid.Rd +++ b/man/buildTuningGrid.Rd @@ -14,8 +14,8 @@ buildTuningGrid( \item{model}{\link[rlang:vector-construction]{rlang::chr} Currently, logistic regression ("LR") is supported.} \item{penalty_vec}{\link[pillar:num]{pillar::num} A vector containing \code{penalty} (regularization -strength) values to try (for logistic regression). Recommended range: -10^-4 to 10^4.} +strength) values to try (for logistic regression). It is recommended to +choose values within a range of 10^-4 to 10^4.} \item{mix_vec}{\link[pillar:num]{pillar::num} A vector containing \code{mixture} values to try for logistic regression. 
0 corresponds to L2 regularization; 1 corresponds to L1; diff --git a/man/buildWflow.Rd b/man/buildWflow.Rd index 1bf7932..7ce915a 100644 --- a/man/buildWflow.Rd +++ b/man/buildWflow.Rd @@ -8,7 +8,7 @@ buildWflow(parsnip_mod, recipe) } \arguments{ \item{parsnip_mod}{A \code{parsnip} model object, such as the output of -\code{buildLRModel()} (random forest and boosted tree support planned)} +\code{buildLRModel()}} \item{recipe}{A recipe, such as the output of \code{buildRecipe()}} } diff --git a/man/calculateSensitivity.Rd b/man/calculateSensitivity.Rd new file mode 100644 index 0000000..d55d54b --- /dev/null +++ b/man/calculateSensitivity.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/core_ml.R +\name{calculateSensitivity} +\alias{calculateSensitivity} +\title{calculateSensitivity()} +\usage{ +calculateSensitivity(test_data_plus_predictions) +} +\arguments{ +\item{test_data_plus_predictions}{Test data (tibble) with an added column for +predicted phenotype labels, such as the output of \code{predict()}} +} +\value{ +sensitivity +} +\description{ +Returns the sensitivity based on the AMR phenotype predictions by an ML model +compared against the actual values. +} diff --git a/man/calculateSpecificity.Rd b/man/calculateSpecificity.Rd new file mode 100644 index 0000000..7b38552 --- /dev/null +++ b/man/calculateSpecificity.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/core_ml.R +\name{calculateSpecificity} +\alias{calculateSpecificity} +\title{calculateSpecificity()} +\usage{ +calculateSpecificity(test_data_plus_predictions) +} +\arguments{ +\item{test_data_plus_predictions}{Test data (tibble) with an added column for +predicted phenotype labels, such as the output of \code{predict()}} +} +\value{ +specificity +} +\description{ +Returns the specificity score based on the AMR phenotype predictions by an ML model +compared against the actual values. 
+} diff --git a/man/runMLPipeline.Rd b/man/runMLPipeline.Rd index 31c542e..16cf04d 100644 --- a/man/runMLPipeline.Rd +++ b/man/runMLPipeline.Rd @@ -15,8 +15,6 @@ runMLPipeline( pca_threshold = 0.95, penalty_vec = 10^seq(-4, -1, length.out = 10), mix_vec = 0:5/5, - min_n_vec = c(2, 6, 12), - tree_vec = c(100, 500, 1000), select_best_metric = "mcc", seed = 123, shuffle_labels = FALSE, @@ -35,8 +33,7 @@ classification for one bug/drug combination) or \code{resistant_classes} (multi-class classification for determining the drug classes to which each genome is resistant), but not both.} -\item{model}{\link[rlang:vector-construction]{rlang::chr} Logistic regression ("LR"), random forest ("RF"), or -boosted tree ("BT")} +\item{model}{\link[rlang:vector-construction]{rlang::chr} Logistic regression ("LR")} \item{split}{\link[pillar:num]{pillar::num} Vector of length 2 indicating the proportion of data to be designated as training and validation, respectively. Note: if \code{test_data} @@ -64,22 +61,14 @@ principle components account} \item{penalty_vec}{\link[pillar:num]{pillar::num} A vector containing \code{penalty} (regularization strength) values to try (for logistic regression). It is recommended to -choose values \code{10^-4} to \code{10^4}.} +choose values within a range of 10^-4 to 10^4.} \item{mix_vec}{\link[pillar:num]{pillar::num} A vector containing \code{mixture} values to try for logistic regression. 0 corresponds to L2 regularization; 1 corresponds to L1; intermediate values correspond to elastic net.} -\item{min_n_vec}{[num] A vector containing \code{min_n} values (the number of data -points in a node required for the node to be split) to try for random forest -or boosted tree. It is recommended to choose values in the range 1 to 100.} - -\item{tree_vec}{[num] A vector containing values to try for the number of -\code{trees} in random forest or boosted tree. 
It is recommended to choose values
-in the range 100 to 1000.}
-
 \item{select_best_metric}{\link[rlang:vector-construction]{rlang::chr} Metric to select best model: "f_meas",
-"pr_auc", or "bal_accuracy"}
+"pr_auc", "mcc", or "bal_accuracy"}
 
 \item{seed}{\link[pillar:num]{pillar::num} For reproducible analysis}
diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd
index 996eb6b..af5bc8e 100644
--- a/vignettes/intro.Rmd
+++ b/vignettes/intro.Rmd
@@ -264,19 +264,19 @@ ml_tibble_reduced <- removeTopFeats(ml_tibble, top_features)
 ### Precision-recall curve
 ```{r plot-prc}
-test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv)
+test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv")
 plotPRC(test_data_plus_predictions)
 ```
 
 ### ROC curve
 ```{r plot-roc}
-test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv)
+test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv")
 plotROC(test_data_plus_predictions)
 ```
 
 ### Variable importance plot
 ```{r plot-vi}
-topfeat <- readr::read_tsv(results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv)
+topfeat <- readr::read_tsv("results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv")
 plotTopFeatsVI(topfeat)
 ```
 
 ### Baseline comparison barplot
@@ -326,7 +326,6 @@ You can label the top N features to highlight the strongest hits (default is 5)
 ```{r}
 plotFishers(fisher_results)
 plotFishers(fisher_results, alpha = 0.01, label_top_n = 5)
-
 ```
 
 ## Wrapper to run all models
@@ -338,14 +337,15 @@ Given a DuckDB file produced by `runDataProcessing()`, it:
 5. 
saves performance metrics, fitted models, predictions, and top feature rankings ``` {r} runModelingPipeline(parquet_duckdb_path, - threads = 16, - n_fold = 5, - split = c(1, 0), - min_n = 25, - prop_vi_top_feats = c(0, 1), - pca_threshold = 0.99, - verbose = TRUE, - use_saved_split = TRUE) + threads = 16, + n_fold = 5, + split = c(1, 0), + min_n = 25, + prop_vi_top_feats = c(0, 1), + pca_threshold = 0.99, + verbose = TRUE, + use_saved_split = TRUE +) ``` Merge the performance and top features of each kind of models into a parquet that will serve as starting data for `amRshiny` package @@ -357,7 +357,7 @@ buildPerformancePq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE ) @@ -367,8 +367,8 @@ buildTopFeatsPq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE -) +) ```