Skip to content

Commit

Permalink
mcmc_diagnostics, hmc_diagnostics and lfo as methods, cutpoints to cu…
Browse files Browse the repository at this point in the history
…tpoint, run-extended
  • Loading branch information
santikka committed May 13, 2024
1 parent b4c7725 commit 278442b
Show file tree
Hide file tree
Showing 17 changed files with 114 additions and 87 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ S3method(get_parameter_names,dynamitefit)
S3method(get_parameter_types,dynamitefit)
S3method(get_priors,dynamitefit)
S3method(get_priors,dynamiteformula)
S3method(hmc_diagnostics,dynamitefit)
S3method(lfo,dynamitefit)
S3method(loo,dynamitefit)
S3method(mcmc_diagnostics,dynamitefit)
S3method(ndraws,dynamitefit)
S3method(nobs,dynamitefit)
S3method(plot,dynamitefit)
Expand Down
2 changes: 1 addition & 1 deletion R/as_data_frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#'
#' * `alpha`\cr Intercept terms (time-invariant or time-varying).
#' * `beta`\cr Time-invariant regression coefficients.
#' * `cutpoints`\cr Cutpoints for ordinal regression.
#' * `cutpoint`\cr Cutpoints for ordinal regression.
#' * `delta`\cr Time-varying regression coefficients.
#' * `nu`\cr Group-level random effects.
#' * `lambda`\cr Factor loadings.
Expand Down
16 changes: 8 additions & 8 deletions R/as_data_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,9 @@ as_data_table_corr <- function(x, draws, n_draws, resps, ...) {
)
}

#' @describeIn as_data_table_default Data Table for a "cutpoints" Parameter
#' @describeIn as_data_table_default Data Table for a "cutpoint" Parameter
#' @noRd
as_data_table_cutpoints <- function(x, draws, response,
as_data_table_cutpoint <- function(x, draws, response,
n_draws, include_fixed, ...) {
channel <- get_channel(x, response)
S <- channel$S
Expand All @@ -726,7 +726,7 @@ as_data_table_cutpoints <- function(x, draws, response,
data.table::rbindlist(lapply(seq_len(S - 1L), function(i) {
idx <- (i - 1L) * n_time2 + seq_len(n_time2)
data.table::data.table(
parameter = paste0("cutpoints_", response),
parameter = paste0("cutpoint_", response),
value = c(
rep(NA, n_na),
c(draws[, , idx])
Expand All @@ -737,7 +737,7 @@ as_data_table_cutpoints <- function(x, draws, response,
}))
} else {
data.table::data.table(
parameter = paste0("cutpoints_", response),
parameter = paste0("cutpoint_", response),
category = rep(seq_len(S - 1L), each = n_draws),
value = c(draws)
)
Expand All @@ -752,7 +752,7 @@ all_types <- c(
"corr",
"corr_nu",
"corr_psi",
"cutpoints",
"cutpoint",
"delta",
"lambda",
"nu",
Expand All @@ -776,7 +776,7 @@ fixed_types <- c(
"corr",
"corr_nu",
"corr_psi",
"cutpoints",
"cutpoint",
"lambda",
"nu",
"omega",
Expand All @@ -794,15 +794,15 @@ fixed_types <- c(

varying_types <- c(
"alpha",
"cutpoints",
"cutpoint",
"delta",
"psi"
)

default_types <- c(
"alpha",
"beta",
"cutpoints",
"cutpoint",
"delta",
"lambda",
"nu",
Expand Down
18 changes: 12 additions & 6 deletions R/lfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
#' is triggered if pareto k values of any group exceeds the threshold.
#'
#' @export
#' @export lfo
#' @family diagnostics
#' @aliases lfo
#' @rdname lfo
#' @param x \[`dynamitefit`]\cr The model fit object.
#' @param L \[`integer(1)`]\cr Positive integer defining how many time points
#' should be used for the initial fit.
Expand Down Expand Up @@ -51,7 +50,13 @@
#' }
#' }
#'
lfo <- function(x, L, verbose = TRUE, k_threshold = 0.7, ...) {
lfo <- function(x, ...) {
UseMethod("lfo", x)
}

#' @export
#' @rdname lfo
lfo.dynamitefit <- function(x, L, verbose = TRUE, k_threshold = 0.7, ...) {
stopifnot_(
!missing(x),
"Argument {.arg x} is missing."
Expand Down Expand Up @@ -272,7 +277,8 @@ lfo <- function(x, L, verbose = TRUE, k_threshold = 0.7, ...) {
#' Print the results from the LFO
#'
#' Prints the summary of the leave-future-out cross-validation.
#' @param x x \[`lfo`]\cr Output of the `lfo` method.
#'
#' @param x \[`lfo`]\cr Output of the `lfo` method.
#' @param ... Ignored.
#' @return Returns `x` invisibly.
#' @export
Expand Down Expand Up @@ -305,9 +311,9 @@ print.lfo <- function(x, ...) {
#' Plots Pareto k values per each time point (with one point per group),
#' together with a horizontal line representing the used threshold.
#'
#' @param x \[`lfo`]\cr Output from the `lfo` function.
#' @param x \[`lfo`]\cr Output of the `lfo` method.
#' @param ... Ignored.
#' @return A ggplot object.
#' @return A `ggplot` object.
#' @export
#' @examples
#' data.table::setDTthreads(1) # For CRAN
Expand Down
22 changes: 19 additions & 3 deletions R/mcmc_diagnostics.R
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
#' Diagnostic Values of a Dynamite Model
#'
#' Prints HMC diagnostics, and lists parameters with smallest effective sample
#' Prints HMC diagnostics and lists parameters with smallest effective sample
#' sizes and largest Rhat values. See [hmc_diagnostics()] and
#' [posterior::default_convergence_measures()] for details.
#'
#' @export
#' @family diagnostics
#' @rdname mcmc_diagnostics
#' @param x \[`dynamitefit`]\cr The model fit object.
#' @param n \[`integer(1)`]\cr How many rows to print in
#' parameter-specific convergence measures. The default is 3. Should be a
#' positive (unrestricted) integer.
#' @param ... Ignored.
#' @return Returns `x` (invisibly).
#' @examples
#' data.table::setDTthreads(1) # For CRAN
#' mcmc_diagnostics(gaussian_example_fit)
#'
mcmc_diagnostics <- function(x, n = 3L) {
mcmc_diagnostics <- function(x, ...) {
UseMethod("mcmc_diagnostics", x)
}

#' @export
#' @rdname mcmc_diagnostics
mcmc_diagnostics.dynamitefit <- function(x, n = 3L, ...) {
stopifnot_(
!missing(x),
"Argument {.arg x} is missing."
Expand Down Expand Up @@ -71,12 +79,20 @@ mcmc_diagnostics <- function(x, n = 3L) {
#'
#' @export
#' @family diagnostics
#' @rdname hmc_diagnostics
#' @param x \[`dynamitefit`]\cr The model fit object.
#' @param ... Ignored.
#' @return Returns `x` (invisibly).
#' data.table::setDTthreads(1) # For CRAN
#' hmc_diagnostics(gaussian_example_fit)
#'
hmc_diagnostics <- function(x) {
hmc_diagnostics <- function(x, ...) {
UseMethod("hmc_diagnostics", x)
}

#' @export
#' @rdname hmc_diagnostics
hmc_diagnostics.dynamitefit <- function(x, ...) {
stopifnot_(
!missing(x),
"Argument {.arg x} is missing."
Expand Down
4 changes: 2 additions & 2 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ plot_fixed <- function(coefs, level, alpha, facet, scales, n_params) {
coefs$type[1L],
alpha = "time-invariant intercepts",
beta = "time-invariant regression coefficients",
cutpoints = "time-invariant cutpoints",
cutpoint = "time-invariant cutpoints",
nu = "random intercepts",
lambda = "latent factor loadings",
"time-invariant parameters"
Expand Down Expand Up @@ -559,7 +559,7 @@ plot_varying <- function(coefs, level, alpha, scales, n_params) {
title_spec <- switch(
coefs$type[1L],
alpha = "time-varying intercepts",
cutpoints = "time-invariant cutpoints",
cutpoint = "time-invariant cutpoints",
delta = "time-invariant regression coefficients",
psi = "latent factors",
"time-varying parameters"
Expand Down
28 changes: 14 additions & 14 deletions R/predict_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ prepare_eval_env_univariate <- function(e, resp, resp_levels, cvars,
idx, type, eval_type) {
alpha <- paste0("alpha_", resp)
beta <- paste0("beta_", resp)
cutpoints <- paste0("cutpoints_", resp)
cutpoint <- paste0("cutpoint_", resp)
delta <- paste0("delta_", resp)
phi <- paste0("phi_", resp)
sigma <- paste0("sigma_", resp)
Expand Down Expand Up @@ -681,14 +681,14 @@ prepare_eval_env_univariate <- function(e, resp, resp_levels, cvars,
stats::pnorm
)
if (cvars$has_fixed_intercept) {
e$cutpoints <- samples[[cutpoints]][idx, , drop = FALSE]
e$cutpoints <- e$cutpoints[rep_len(e$n_draws, e$k), , drop = FALSE]
e$cutpoint <- samples[[cutpoint]][idx, , drop = FALSE]
e$cutpoint <- e$cutpoint[rep_len(e$n_draws, e$k), , drop = FALSE]
e$alpha <- matrix(0.0, e$n_draws, 1L)
}
if (cvars$has_varying_intercept) {
e$cutpoints <- samples[[cutpoints]][idx, , , drop = FALSE]
e$cutpoints <- e$cutpoints[rep_len(e$n_draws, e$k), , , drop = FALSE]
e$alpha <- matrix(0.0, e$n_draws, dim(e$cutpoints)[2L])
e$cutpoint <- samples[[cutpoint]][idx, , , drop = FALSE]
e$cutpoint <- e$cutpoint[rep_len(e$n_draws, e$k), , , drop = FALSE]
e$alpha <- matrix(0.0, e$n_draws, dim(e$cutpoint)[2L])
}
} else {
if (cvars$has_fixed_intercept) {
Expand Down Expand Up @@ -1063,8 +1063,8 @@ predict_expr$fitted$categorical <- "
"

predict_expr$fitted$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) -
cbind(invlink(xbeta - cutpoint{idx_cuts}), 0)
for (s in 1:d) {{
data.table::set(
x = out,
Expand Down Expand Up @@ -1173,8 +1173,8 @@ predict_expr$predicted$categorical <- "
"

predict_expr$predicted$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) -
cbind(invlink(xbeta - cutpoint{idx_cuts}), 0)
data.table::set(
x = out,
i = idx_data,
Expand Down Expand Up @@ -1322,8 +1322,8 @@ predict_expr$mean$categorical <- "
"

predict_expr$mean$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) -
cbind(invlink(xbeta - cutpoint{idx_cuts}), 0)
for (s in 1:d) {{
data.table::set(
x = out,
Expand Down Expand Up @@ -1462,8 +1462,8 @@ predict_expr$loglik$categorical <- "
"

predict_expr$loglik$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
prob <- cbind(1, invlink(xbeta - cutpoint{idx_cuts})) -
cbind(invlink(xbeta - cutpoint{idx_cuts}), 0)
data.table::set(
x = out,
i = idx,
Expand Down
8 changes: 4 additions & 4 deletions R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ prepare_channel_cumulative <- function(y, Y, channel, sampling,
fixed_cutpoints <- channel$has_fixed_intercept
if (fixed_cutpoints) {
cutpoint_priors <- data.frame(
parameter = paste0("cutpoints_", y, "_", seq_len(S_y - 1)),
parameter = paste0("cutpoint_", y, "_", seq_len(S_y - 1)),
response = y,
prior = "std_normal()",
type = "cutpoint",
Expand All @@ -793,13 +793,13 @@ prepare_channel_cumulative <- function(y, Y, channel, sampling,
out$priors <- out$priors[out$priors$type != "alpha", ]
out$channel$prior_distr$alpha_prior_distr <- NULL
if (is.null(priors)) {
out$channel$prior_distr$cutpoints_prior_distr <- cutpoint_priors$prior
names(out$channel$prior_distr$cutpoints_prior_distr) <- cutpoint_priors$category
out$channel$prior_distr$cutpoint_prior_distr <- cutpoint_priors$prior
names(out$channel$prior_distr$cutpoint_prior_distr) <- cutpoint_priors$category
out$priors <- rbind(cutpoint_priors, out$priors)
} else {
priors <- priors[priors$response == y, ]
pdef <- priors[priors$type == "cutpoint", ]
out$channel$prior_distr$cutpoints_prior_distr <- pdef$prior
out$channel$prior_distr$cutpoint_prior_distr <- pdef$prior
defaults <- rbind(
cutpoint_priors,
default_priors(y, channel, mean_gamma, sd_gamma, mean_y, sd_y)$priors,
Expand Down
Loading

0 comments on commit 278442b

Please sign in to comment.