Skip to content

Commit

Permalink
droplevels in full_model.matrix, run-extended
Browse files: browse the repository at this point in the history
  • Loading branch information
santikka committed May 30, 2024
1 parent 6a01d3e commit 7928f95
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 20 deletions.
18 changes: 12 additions & 6 deletions R/model_matrix.R
Original file line number Diff line number Diff line change
@@ -1,35 +1,41 @@
#' Combine `model.matrix` Objects of All Formulas of a `dynamiteformula`
#'
#' @inheritParams dynamite
#' @inheritParams prepare_stan_input
#' @srrstats {RE1.3, RE1.3a} `full_model.matrix` preserves relevant attributes.
#' @noRd
full_model.matrix <- function(dformula, data, verbose) {
full_model.matrix <- function(dformula, data, group_var, fixed, verbose) {
model_matrices <- vector(mode = "list", length = length(dformula))
model_matrices_type <- vector(mode = "list", length = length(dformula))
types <- c("fixed", "varying", "random")
idx <- data[,
.I[base::seq.int(fixed + 1L, .N)],
by = group,
env = list(fixed = fixed, group = group_var)
]$V1
data_nonfixed <- droplevels(data[idx, , env = list(idx = idx)])
for (i in seq_along(dformula)) {
mm <- stats::model.matrix.lm(
dformula[[i]]$formula,
data = data,
data = data_nonfixed,
na.action = na.pass
)
if (verbose) {
test_collinearity(dformula[[i]]$resp, mm, data)
test_collinearity(dformula[[i]]$resp, mm, data_nonfixed)
}
model_matrices_type[[i]] <- list()
for (type in c("fixed", "varying", "random")) {
type_formula <- get_type_formula(dformula[[i]], type)
if (!is.null(type_formula)) {
model_matrices_type[[i]][[type]] <- stats::model.matrix.lm(
type_formula,
data = data,
data = data_nonfixed,
na.action = na.pass
)
}
}
tmp <- do.call(cbind, model_matrices_type[[i]])
ifelse_(identical(length(tmp), 0L),
model_matrices[[i]] <- matrix(nrow = nrow(mm), ncol = 0),
model_matrices[[i]] <- matrix(nrow = nrow(mm), ncol = 0L),
model_matrices[[i]] <- tmp
)
}
Expand Down
5 changes: 2 additions & 3 deletions R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ prepare_stan_input <- function(dformula, data, group_var, time_var,
"Can't find variable{?s} {.var {resp[resp_missing]}} in {.arg data}."
)
specials <- lapply(dformula, evaluate_specials, data = data)
model_matrix <- full_model.matrix(dformula, data, verbose)
model_matrix <- full_model.matrix(dformula, data, group_var, fixed, verbose)
cg <- attr(dformula, "channel_groups")
n_cg <- n_unique(cg)
n_channels <- length(resp_names)
Expand Down Expand Up @@ -85,8 +85,7 @@ prepare_stan_input <- function(dformula, data, group_var, time_var,
N <- n_unique(group)
K <- ncol(model_matrix)
X <- model_matrix[, ]
dim(X) <- c(T_full, N, K)
X <- X[T_idx, , , drop = FALSE]
dim(X) <- c(T_full - fixed, N, K)
x_tmp <- X[1L, , , drop = FALSE]
sd_x <- pmax(
stats::setNames(apply(X, 3L, sd, na.rm = TRUE), colnames(model_matrix)),
Expand Down
Binary file modified data/categorical_example_fit.rda
Binary file not shown.
Binary file modified data/gaussian_example_fit.rda
Binary file not shown.
Binary file modified data/multichannel_example_fit.rda
Binary file not shown.
34 changes: 23 additions & 11 deletions tests/testthat/test-warnings.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,32 @@ test_that("factor time conversion warns", {
test_that("perfect collinearity warns", {
f1 <- obs(y ~ -1 + x + z, family = "gaussian")
f2 <- obs(y ~ z, family = "gaussian")
test_data1 <- data.frame(y = rnorm(10), x = rep(1, 10), z = rep(2, 10))
test_data2 <- data.frame(y = rep(1, 10), x = rep(1, 10), z = rnorm(10))
test_data1 <- data.table::data.table(
y = rnorm(10),
x = rep(1, 10),
z = rep(2, 10),
id = 1L
)
test_data2 <- data.table::data.table(
y = rep(1, 10),
x = rep(1, 10),
z = rnorm(10),
id = 1L
)
expect_warning(
full_model.matrix(f1, test_data1, TRUE),
full_model.matrix(f1, test_data1, "id", 0L, TRUE),
"Perfect collinearity found between predictor variables of channel `y`\\."
)
expect_warning(
full_model.matrix(f2, test_data2, TRUE),
full_model.matrix(f2, test_data2, "id", 0L, TRUE),
paste0(
"Perfect collinearity found between response and predictor variable:\n",
"i Response variable `y` is perfectly collinear ",
"with predictor variable `\\(Intercept\\)`\\."
)
)
expect_warning(
full_model.matrix(f1, test_data2, TRUE),
full_model.matrix(f1, test_data2, "id", 0L, TRUE),
paste0(
"Perfect collinearity found between response and predictor variable:\n",
"i Response variable `y` is perfectly collinear ",
Expand All @@ -50,14 +60,15 @@ test_that("perfect collinearity warns", {

test_that("too few observations warns", {
f <- obs(y ~ x + z + w, family = "gaussian")
test_data <- data.frame(
test_data <- data.table::data.table(
y = rnorm(3),
x = rnorm(3),
z = rnorm(3),
w = rnorm(3)
w = rnorm(3),
id = 1L
)
expect_warning(
full_model.matrix(f, test_data, TRUE),
full_model.matrix(f, test_data, "id", 0L, TRUE),
paste0(
"Number of non-missing observations 3 in channel `y` ",
"is less than 4, the number of predictors \\(including possible ",
Expand All @@ -68,13 +79,14 @@ test_that("too few observations warns", {

test_that("zero predictor warns", {
f <- obs(y ~ -1 + x + z, family = "gaussian")
test_data <- data.frame(
test_data <- data.table::data.table(
y = rnorm(6),
x = c(NA, rnorm(2), NA, rnorm(2)),
z = factor(1:3)
z = factor(1:3),
id = 1L
)
expect_warning(
full_model.matrix(f, test_data, TRUE),
full_model.matrix(f, test_data, "id", 0L, TRUE),
paste0(
"Predictor `z1` contains only zeros in the complete case rows of the ",
"design matrix for the channel `y`\\."
Expand Down

0 comments on commit 7928f95

Please sign in to comment.