From 278442b4f5187a321cc255e1eed6453570306f3c Mon Sep 17 00:00:00 2001 From: Santtu Tikka Date: Mon, 13 May 2024 13:33:44 +0300 Subject: [PATCH] mcmc_diagnostics, hmc_diagnostics and lfo as methods, cutpoints to cutpoint, run-extended --- NAMESPACE | 3 +++ R/as_data_frame.R | 2 +- R/as_data_table.R | 16 ++++++++-------- R/lfo.R | 18 ++++++++++++------ R/mcmc_diagnostics.R | 22 +++++++++++++++++++--- R/plot.R | 4 ++-- R/predict_helpers.R | 28 ++++++++++++++-------------- R/prepare_stan_input.R | 8 ++++---- R/stanblocks_families.R | 32 +++++++++++++++++--------------- man/as.data.frame.dynamitefit.Rd | 2 +- man/hmc_diagnostics.Rd | 7 ++++++- man/lfo.Rd | 13 ++++++++----- man/mcmc_diagnostics.Rd | 9 +++++++-- man/plot.lfo.Rd | 4 ++-- man/print.lfo.Rd | 2 +- tests/testthat/test-errors.R | 27 +++++++-------------------- tests/testthat/test-extended.R | 4 ++-- 17 files changed, 114 insertions(+), 87 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index f5dcc16..131f8da 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -19,7 +19,10 @@ S3method(get_parameter_names,dynamitefit) S3method(get_parameter_types,dynamitefit) S3method(get_priors,dynamitefit) S3method(get_priors,dynamiteformula) +S3method(hmc_diagnostics,dynamitefit) +S3method(lfo,dynamitefit) S3method(loo,dynamitefit) +S3method(mcmc_diagnostics,dynamitefit) S3method(ndraws,dynamitefit) S3method(nobs,dynamitefit) S3method(plot,dynamitefit) diff --git a/R/as_data_frame.R b/R/as_data_frame.R index 2c51ac6..8dc7b0b 100644 --- a/R/as_data_frame.R +++ b/R/as_data_frame.R @@ -11,7 +11,7 @@ #' #' * `alpha`\cr Intercept terms (time-invariant or time-varying). #' * `beta`\cr Time-invariant regression coefficients. -#' * `cutpoints`\cr Cutpoints for ordinal regression. +#' * `cutpoint`\cr Cutpoints for ordinal regression. #' * `delta`\cr Time-varying regression coefficients. #' * `nu`\cr Group-level random effects. #' * `lambda`\cr Factor loadings. diff --git a/R/as_data_table.R b/R/as_data_table.R index 078eab1..0907aec 100644 --- a/R/as_data_table.R +++ b/R/as_data_table.R @@ -706,9 +706,9 @@ as_data_table_corr <- function(x, draws, n_draws, resps, ...) { ) } -#' @describeIn as_data_table_default Data Table for a "cutpoints" Parameter +#' @describeIn as_data_table_default Data Table for a "cutpoint" Parameter #' @noRd -as_data_table_cutpoints <- function(x, draws, response, +as_data_table_cutpoint <- function(x, draws, response, n_draws, include_fixed, ...) { channel <- get_channel(x, response) S <- channel$S @@ -726,7 +726,7 @@ as_data_table_cutpoints <- function(x, draws, response, data.table::rbindlist(lapply(seq_len(S - 1L), function(i) { idx <- (i - 1L) * n_time2 + seq_len(n_time2) data.table::data.table( - parameter = paste0("cutpoints_", response), + parameter = paste0("cutpoint_", response), value = c( rep(NA, n_na), c(draws[, , idx]) @@ -737,7 +737,7 @@ as_data_table_cutpoints <- function(x, draws, response, })) } else { data.table::data.table( - parameter = paste0("cutpoints_", response), + parameter = paste0("cutpoint_", response), category = rep(seq_len(S - 1L), each = n_draws), value = c(draws) ) @@ -752,7 +752,7 @@ all_types <- c( "corr", "corr_nu", "corr_psi", - "cutpoints", + "cutpoint", "delta", "lambda", "nu", @@ -776,7 +776,7 @@ fixed_types <- c( "corr", "corr_nu", "corr_psi", - "cutpoints", + "cutpoint", "lambda", "nu", "omega", @@ -794,7 +794,7 @@ fixed_types <- c( varying_types <- c( "alpha", - "cutpoints", + "cutpoint", "delta", "psi" ) @@ -802,7 +802,7 @@ varying_types <- c( default_types <- c( "alpha", "beta", - "cutpoints", + "cutpoint", "delta", "lambda", "nu", diff --git a/R/lfo.R b/R/lfo.R index a279d20..39dc844 100644 --- a/R/lfo.R +++ b/R/lfo.R @@ -9,9 +9,8 @@ #' is triggered if pareto k values of any group exceeds the threshold. #' #' @export -#' @export lfo #' @family diagnostics -#' @aliases lfo +#' @rdname lfo #' @param x \[`dynamitefit`]\cr The model fit object. #' @param L \[`integer(1)`]\cr Positive integer defining how many time points #' should be used for the initial fit. @@ -51,7 +50,13 @@ #' } #' } #' -lfo <- function(x, L, verbose = TRUE, k_threshold = 0.7, ...) { +lfo <- function(x, ...) { + UseMethod("lfo", x) +} + +#' @export +#' @rdname lfo +lfo.dynamitefit <- function(x, L, verbose = TRUE, k_threshold = 0.7, ...) { stopifnot_( !missing(x), "Argument {.arg x} is missing." @@ -272,7 +277,8 @@ lfo <- function(x, L, verbose = TRUE, k_threshold = 0.7, ...) { #' Print the results from the LFO #' #' Prints the summary of the leave-future-out cross-validation. -#' @param x x \[`lfo`]\cr Output of the `lfo` method. +#' +#' @param x \[`lfo`]\cr Output of the `lfo` method. #' @param ... Ignored. #' @return Returns `x` invisibly. #' @export @@ -305,9 +311,9 @@ print.lfo <- function(x, ...) { #' Plots Pareto k values per each time point (with one point per group), #' together with a horizontal line representing the used threshold. #' -#' @param x \[`lfo`]\cr Output from the `lfo` function. +#' @param x \[`lfo`]\cr Output of the `lfo` method. #' @param ... Ignored. -#' @return A ggplot object. +#' @return A `ggplot` object. #' @export #' @examples #' data.table::setDTthreads(1) # For CRAN diff --git a/R/mcmc_diagnostics.R b/R/mcmc_diagnostics.R index e9a7434..2d9201c 100644 --- a/R/mcmc_diagnostics.R +++ b/R/mcmc_diagnostics.R @@ -1,21 +1,29 @@ #' Diagnostic Values of a Dynamite Model #' -#' Prints HMC diagnostics, and lists parameters with smallest effective sample +#' Prints HMC diagnostics and lists parameters with smallest effective sample #' sizes and largest Rhat values. See [hmc_diagnostics()] and #' [posterior::default_convergence_measures()] for details. #' #' @export #' @family diagnostics +#' @rdname mcmc_diagnostics #' @param x \[`dynamitefit`]\cr The model fit object. #' @param n \[`integer(1)`]\cr How many rows to print in #' parameter-specific convergence measures. The default is 3. Should be a #' positive (unrestricted) integer. +#' @param ... Ignored. #' @return Returns `x` (invisibly). #' @examples #' data.table::setDTthreads(1) # For CRAN #' mcmc_diagnostics(gaussian_example_fit) #' -mcmc_diagnostics <- function(x, n = 3L) { +mcmc_diagnostics <- function(x, ...) { + UseMethod("mcmc_diagnostics", x) +} + +#' @export +#' @rdname mcmc_diagnostics +mcmc_diagnostics.dynamitefit <- function(x, n = 3L, ...) { stopifnot_( !missing(x), "Argument {.arg x} is missing." @@ -71,12 +79,20 @@ mcmc_diagnostics <- function(x, n = 3L) { #' #' @export #' @family diagnostics +#' @rdname hmc_diagnostics #' @param x \[`dynamitefit`]\cr The model fit object. +#' @param ... Ignored. #' @return Returns `x` (invisibly). #' data.table::setDTthreads(1) # For CRAN #' hmc_diagnostics(gaussian_example_fit) #' -hmc_diagnostics <- function(x) { +hmc_diagnostics <- function(x, ...) { + UseMethod("hmc_diagnostics", x) +} + +#' @export +#' @rdname hmc_diagnostics +hmc_diagnostics.dynamitefit <- function(x, ...) { stopifnot_( !missing(x), "Argument {.arg x} is missing." diff --git a/R/plot.R b/R/plot.R index caac8d5..2db1eb5 100644 --- a/R/plot.R +++ b/R/plot.R @@ -501,7 +501,7 @@ plot_fixed <- function(coefs, level, alpha, facet, scales, n_params) { coefs$type[1L], alpha = "time-invariant intercepts", beta = "time-invariant regression coefficients", - cutpoints = "time-invariant cutpoints", + cutpoint = "time-invariant cutpoints", nu = "random intercepts", lambda = "latent factor loadings", "time-invariant parameters" @@ -559,7 +559,7 @@ plot_varying <- function(coefs, level, alpha, scales, n_params) { title_spec <- switch( coefs$type[1L], alpha = "time-varying intercepts", - cutpoints = "time-invariant cutpoints", + cutpoint = "time-invariant cutpoints", delta = "time-invariant regression coefficients", psi = "latent factors", "time-varying parameters" diff --git a/R/predict_helpers.R b/R/predict_helpers.R index 72ede1c..f290b10 100644 --- a/R/predict_helpers.R +++ b/R/predict_helpers.R @@ -644,7 +644,7 @@ prepare_eval_env_univariate <- function(e, resp, resp_levels, cvars, idx, type, eval_type) { alpha <- paste0("alpha_", resp) beta <- paste0("beta_", resp) - cutpoints <- paste0("cutpoints_", resp) + cutpoint <- paste0("cutpoint_", resp) delta <- paste0("delta_", resp) phi <- paste0("phi_", resp) sigma <- paste0("sigma_", resp) @@ -681,14 +681,14 @@ prepare_eval_env_univariate <- function(e, resp, resp_levels, cvars, stats::pnorm ) if (cvars$has_fixed_intercept) { - e$cutpoints <- samples[[cutpoints]][idx, , drop = FALSE] - e$cutpoints <- e$cutpoints[rep_len(e$n_draws, e$k), , drop = FALSE] + e$cutpoint <- samples[[cutpoint]][idx, , drop = FALSE] + e$cutpoint <- e$cutpoint[rep_len(e$n_draws, e$k), , drop = FALSE] e$alpha <- matrix(0.0, e$n_draws, 1L) } if (cvars$has_varying_intercept) { - e$cutpoints <- samples[[cutpoints]][idx, , , drop = FALSE] - e$cutpoints <- e$cutpoints[rep_len(e$n_draws, e$k), , , drop = FALSE] - e$alpha <- matrix(0.0, e$n_draws, dim(e$cutpoints)[2L]) + e$cutpoint <- samples[[cutpoint]][idx, , , drop = FALSE] + e$cutpoint <- e$cutpoint[rep_len(e$n_draws, e$k), , , drop = FALSE] + e$alpha <- matrix(0.0, e$n_draws, dim(e$cutpoint)[2L]) } } else { if (cvars$has_fixed_intercept) { @@ -1063,8 +1063,8 @@ predict_expr$fitted$categorical <- " " predict_expr$fitted$cumulative <- " - prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) - - cbind(invlink(xbeta - cutpoints{idx_cuts}), 0) + prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) - + cbind(invlink(xbeta - cutpoint{idx_cuts}), 0) for (s in 1:d) {{ data.table::set( x = out, @@ -1173,8 +1173,8 @@ predict_expr$predicted$categorical <- " " predict_expr$predicted$cumulative <- " - prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) - - cbind(invlink(xbeta - cutpoints{idx_cuts}), 0) + prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) - + cbind(invlink(xbeta - cutpoint{idx_cuts}), 0) data.table::set( x = out, i = idx_data, @@ -1322,8 +1322,8 @@ predict_expr$mean$categorical <- " " predict_expr$mean$cumulative <- " - prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) - - cbind(invlink(xbeta - cutpoints{idx_cuts}), 0) + prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) - + cbind(invlink(xbeta - cutpoint{idx_cuts}), 0) for (s in 1:d) {{ data.table::set( x = out, @@ -1462,8 +1462,8 @@ predict_expr$loglik$categorical <- " " predict_expr$loglik$cumulative <- " - prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) - - cbind(invlink(xbeta - cutpoints{idx_cuts}), 0) + prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) - + cbind(invlink(xbeta - cutpoint{idx_cuts}), 0) data.table::set( x = out, i = idx, diff --git a/R/prepare_stan_input.R b/R/prepare_stan_input.R index 85374a5..762533e 100644 --- a/R/prepare_stan_input.R +++ b/R/prepare_stan_input.R @@ -770,7 +770,7 @@ prepare_channel_cumulative <- function(y, Y, channel, sampling, fixed_cutpoints <- channel$has_fixed_intercept if (fixed_cutpoints) { cutpoint_priors <- data.frame( - parameter = paste0("cutpoints_", y, "_", seq_len(S_y - 1)), + parameter = paste0("cutpoint_", y, "_", seq_len(S_y - 1)), response = y, prior = "std_normal()", type = "cutpoint", @@ -793,13 +793,13 @@ prepare_channel_cumulative <- function(y, Y, channel, sampling, out$priors <- out$priors[out$priors$type != "alpha", ] out$channel$prior_distr$alpha_prior_distr <- NULL if (is.null(priors)) { - out$channel$prior_distr$cutpoints_prior_distr <- cutpoint_priors$prior - names(out$channel$prior_distr$cutpoints_prior_distr) <- cutpoint_priors$category + out$channel$prior_distr$cutpoint_prior_distr <- cutpoint_priors$prior + names(out$channel$prior_distr$cutpoint_prior_distr) <- cutpoint_priors$category out$priors <- rbind(cutpoint_priors, out$priors) } else { priors <- priors[priors$response == y, ] pdef <- priors[priors$type == "cutpoint", ] - out$channel$prior_distr$cutpoints_prior_distr <- pdef$prior + out$channel$prior_distr$cutpoint_prior_distr <- pdef$prior defaults <- rbind( cutpoint_priors, default_priors(y, channel, mean_gamma, sd_gamma, mean_y, sd_y)$priors, diff --git a/R/stanblocks_families.R b/R/stanblocks_families.R index 9a10f35..1d2e83e 100644 --- a/R/stanblocks_families.R +++ b/R/stanblocks_families.R @@ -632,20 +632,20 @@ loglik_lines_cumulative <- function(y, obs, idt, default, family, u <- default$u is_logit <- identical("logit", family$link) link <- ifelse_(is_logit, "logistic", "probit") - cutpoints <- ifelse_( + cutpoint <- ifelse_( has_varying_intercept, - glue::glue("cutpoints_{y}[t, ]"), - glue::glue("cutpoints_{y}") + glue::glue("cutpoint_{y}[t, ]"), + glue::glue("cutpoint_{y}") ) likelihood <- ifelse_( default$use_glm && is_logit, glue::glue( "ll += ordered_{link}_glm_l{u}pmf(y_{y}[t, {obs}] | X[t][{obs}, J_{y}], ", - "gamma__{y}, {cutpoints});" + "gamma__{y}, {cutpoint});" ), glue::glue( "ll += ordered_{link}_l{u}pmf(y_{y}[t, {obs}] | ", - "intercept_{y}, {cutpoints});" + "intercept_{y}, {cutpoint});" ) ) if (default$threading) { @@ -653,11 +653,11 @@ loglik_lines_cumulative <- function(y, obs, idt, default, family, default$fun_args, onlyif( has_fixed_intercept, - glue::glue("vector cutpoints_{y}") + glue::glue("vector cutpoint_{y}") ), onlyif( has_varying_intercept, - glue::glue("array[] vector cutpoints_{y}") + glue::glue("array[] vector cutpoint_{y}") ) ) } @@ -1459,7 +1459,7 @@ parameters_lines_cumulative <- function(y, idt, default, default, onlyif( has_fixed_intercept, - "ordered[S_{y} - 1] cutpoints_{y}; // Cutpoints" + "ordered[S_{y} - 1] cutpoint_{y}; // Cutpoints" ), .indent = idt(c(0, 1)) ) @@ -1761,14 +1761,14 @@ transformed_parameters_lines_cumulative <- function(y, categories, if (has_varying_intercept) { declare_cutpoints <- glue::glue( stan_array( - backend, "vector", "cutpoints_{y}", "T", "", "S_{y} - 1" + backend, "vector", "cutpoint_{y}", "T", "", "S_{y} - 1" ) ) assign_cutpoints <- vapply( seq_along(categories), function(s) { glue::glue( - "cutpoints_{y}[, {s}] = alpha_{y}_{s};" + "cutpoint_{y}[, {s}] = alpha_{y}_{s};" ) }, character(1L) @@ -1776,9 +1776,9 @@ transformed_parameters_lines_cumulative <- function(y, categories, state_cutpoints <- paste_rows( assign_cutpoints, "for (t in 1:T) {{", - "vector[S_{y}] tmp = exp(append_row(0, cutpoints_{y}[t, ]));", + "vector[S_{y}] tmp = exp(append_row(0, cutpoint_{y}[t, ]));", "for (s in 1:(S_{y} - 1)) {{", - "cutpoints_{y}[t, s] = log(sum(tmp[1:s]) / sum(tmp[(s + 1):S_{y}]));", + "cutpoint_{y}[t, s] = log(sum(tmp[1:s]) / sum(tmp[(s + 1):S_{y}]));", "}}", "}}", .indent = idt(c(1, 1, 2, 2, 3, 2, 1)) @@ -2158,12 +2158,14 @@ model_lines_cumulative <- function(y, obs, idt, priors, if (threading) { default$fun_call_args <- cs( default$fun_call_args, - glue::glue("cutpoints_{y}") + glue::glue("cutpoint_{y}") ) } - S <- length(prior_distr$cutpoints_prior_distr) + S <- length(prior_distr$cutpoint_prior_distr) paste_rows( - glue::glue("cutpoints_{y}[{seq_len(S)}] ~ {prior_distr$cutpoints_prior_distr};"), + glue::glue( + "cutpoint_{y}[{seq_len(S)}] ~ {prior_distr$cutpoint_prior_distr};" + ), priors, model_lines_default(y, obs, idt, threading, default, ...), .parse = FALSE diff --git a/man/as.data.frame.dynamitefit.Rd b/man/as.data.frame.dynamitefit.Rd index 691bebc..7f74835 100644 --- a/man/as.data.frame.dynamitefit.Rd +++ b/man/as.data.frame.dynamitefit.Rd @@ -81,7 +81,7 @@ Potential values for the \code{types} argument are: \itemize{ \item \code{alpha}\cr Intercept terms (time-invariant or time-varying). \item \code{beta}\cr Time-invariant regression coefficients. -\item \code{cutpoints}\cr Cutpoints for ordinal regression. +\item \code{cutpoint}\cr Cutpoints for ordinal regression. \item \code{delta}\cr Time-varying regression coefficients. \item \code{nu}\cr Group-level random effects. \item \code{lambda}\cr Factor loadings. diff --git a/man/hmc_diagnostics.Rd b/man/hmc_diagnostics.Rd index f1de3c5..8e6f1df 100644 --- a/man/hmc_diagnostics.Rd +++ b/man/hmc_diagnostics.Rd @@ -2,12 +2,17 @@ % Please edit documentation in R/mcmc_diagnostics.R \name{hmc_diagnostics} \alias{hmc_diagnostics} +\alias{hmc_diagnostics.dynamitefit} \title{HMC Diagnostics for a Dynamite Model} \usage{ -hmc_diagnostics(x) +hmc_diagnostics(x, ...) + +\method{hmc_diagnostics}{dynamitefit}(x, ...) } \arguments{ \item{x}{[\code{dynamitefit}]\cr The model fit object.} + +\item{...}{Ignored.} } \value{ Returns \code{x} (invisibly). diff --git a/man/lfo.Rd b/man/lfo.Rd index 11296d9..a9162ab 100644 --- a/man/lfo.Rd +++ b/man/lfo.Rd @@ -2,13 +2,20 @@ % Please edit documentation in R/lfo.R \name{lfo} \alias{lfo} +\alias{lfo.dynamitefit} \title{Approximate Leave-Future-Out (LFO) Cross-validation} \usage{ -lfo(x, L, verbose = TRUE, k_threshold = 0.7, ...) +lfo(x, ...) + +\method{lfo}{dynamitefit}(x, L, verbose = TRUE, k_threshold = 0.7, ...) } \arguments{ \item{x}{[\code{dynamitefit}]\cr The model fit object.} +\item{...}{Additional arguments passed to \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}} or +\code{\link[cmdstanr:model-method-sample]{cmdstanr::sample()}}, such as \code{chains} and \code{cores} (\code{parallel_chains} in +\code{cmdstanr}).} + \item{L}{[\code{integer(1)}]\cr Positive integer defining how many time points should be used for the initial fit.} @@ -17,10 +24,6 @@ the LFO computations to the console.} \item{k_threshold}{[\code{numeric(1)}]\cr Threshold for the Pareto k estimate triggering refit. Default is 0.7.} - -\item{...}{Additional arguments passed to \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}} or -\code{\link[cmdstanr:model-method-sample]{cmdstanr::sample()}}, such as \code{chains} and \code{cores} (\code{parallel_chains} in -\code{cmdstanr}).} } \value{ An \code{lfo} object which is a \code{list} with the following components: diff --git a/man/mcmc_diagnostics.Rd b/man/mcmc_diagnostics.Rd index 04db359..5cce21a 100644 --- a/man/mcmc_diagnostics.Rd +++ b/man/mcmc_diagnostics.Rd @@ -2,13 +2,18 @@ % Please edit documentation in R/mcmc_diagnostics.R \name{mcmc_diagnostics} \alias{mcmc_diagnostics} +\alias{mcmc_diagnostics.dynamitefit} \title{Diagnostic Values of a Dynamite Model} \usage{ -mcmc_diagnostics(x, n = 3L) +mcmc_diagnostics(x, ...) + +\method{mcmc_diagnostics}{dynamitefit}(x, n = 3L, ...) } \arguments{ \item{x}{[\code{dynamitefit}]\cr The model fit object.} +\item{...}{Ignored.} + \item{n}{[\code{integer(1)}]\cr How many rows to print in parameter-specific convergence measures. The default is 3. Should be a positive (unrestricted) integer.} @@ -17,7 +22,7 @@ positive (unrestricted) integer.} Returns \code{x} (invisibly). } \description{ -Prints HMC diagnostics, and lists parameters with smallest effective sample +Prints HMC diagnostics and lists parameters with smallest effective sample sizes and largest Rhat values. See \code{\link[=hmc_diagnostics]{hmc_diagnostics()}} and \code{\link[posterior:draws_summary]{posterior::default_convergence_measures()}} for details. } diff --git a/man/plot.lfo.Rd b/man/plot.lfo.Rd index 34bc6d4..37cd8e8 100644 --- a/man/plot.lfo.Rd +++ b/man/plot.lfo.Rd @@ -7,12 +7,12 @@ \method{plot}{lfo}(x, ...) } \arguments{ -\item{x}{[\code{lfo}]\cr Output from the \code{lfo} function.} +\item{x}{[\code{lfo}]\cr Output of the \code{lfo} method.} \item{...}{Ignored.} } \value{ -A ggplot object. +A \code{ggplot} object. } \description{ Plots Pareto k values per each time point (with one point per group), diff --git a/man/print.lfo.Rd b/man/print.lfo.Rd index 2c69d61..166a30d 100644 --- a/man/print.lfo.Rd +++ b/man/print.lfo.Rd @@ -7,7 +7,7 @@ \method{print}{lfo}(x, ...) } \arguments{ -\item{x}{x [\code{lfo}]\cr Output of the \code{lfo} method.} +\item{x}{[\code{lfo}]\cr Output of the \code{lfo} method.} \item{...}{Ignored.} } diff --git a/tests/testthat/test-errors.R b/tests/testthat/test-errors.R index 1b4a983..b37915e 100644 --- a/tests/testthat/test-errors.R +++ b/tests/testthat/test-errors.R @@ -989,17 +989,8 @@ test_that("output for missing argument fails", { "summary", "update" ) - non_s3_methods <- c( - "hmc_diagnostics", - "lfo", - "mcmc_diagnostics" - ) for (m in methods) { - call_fun <- ifelse_( - m %in% non_s3_methods, - m, - paste0(m, ".dynamitefit") - ) + call_fun <- paste0(m, ".dynamitefit") expect_error( do.call(call_fun, args = list()), "Argument `.+` is missing" @@ -1015,7 +1006,10 @@ test_that("output for non dynamitefit objects fails", { "coef", "fitted", "formula", + "hmc_diagnostics", + "lfo", "loo", + "mcmc_diagnostics", "ndraws", "nobs", "plot", @@ -1033,22 +1027,13 @@ test_that("output for non dynamitefit objects fails", { "summary", "update" ) - non_s3_methods <- c( - "hmc_diagnostics", - "lfo", - "mcmc_diagnostics" - ) for (m in methods) { args <- ifelse_( m %in% object_arg_methods, list(object = 1L), list(x = 1L) ) - call_fun <- ifelse_( - m %in% non_s3_methods, - m, - paste0(m, ".dynamitefit") - ) + call_fun <- paste0(m, ".dynamitefit") expect_error( do.call(call_fun, args = args), "Argument `.+` must be a object\\." @@ -1061,6 +1046,8 @@ test_that("output without Stan fit fails", { "as.data.frame", "as_draws_df", "fitted", + "lfo", + "loo", "predict", "ndraws" ) diff --git a/tests/testthat/test-extended.R b/tests/testthat/test-extended.R index 3c2e514..008fea6 100644 --- a/tests/testthat/test-extended.R +++ b/tests/testthat/test-extended.R @@ -235,7 +235,7 @@ test_that("time-invariant cumulative fit and predict work", { NA ) expect_error( - as.data.table(fit, types = "cutpoints"), + as.data.table(fit, types = "cutpoint"), NA ) expect_error( @@ -317,7 +317,7 @@ test_that("time-varying cutpoints for cumulative works", { NA ) expect_error( - as.data.table(fit, types = "cutpoints"), + as.data.table(fit, types = "cutpoint"), NA ) expect_error(