Skip to content

Commit

Permalink
Support the sparse mode for the main data matrix Y
Browse files Browse the repository at this point in the history
  • Loading branch information
statwangz committed Sep 19, 2023
1 parent cbb707e commit 9bd2361
Show file tree
Hide file tree
Showing 14 changed files with 458 additions and 89 deletions.
5 changes: 1 addition & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ export(appendMFAIR)
export(createMFAIR)
export(fitBack)
export(fitGreedy)
export(fitSFFully)
export(fitSFMissing)
export(getELBO)
export(getImportance)
export(getImportanceSF)
export(initSF)
export(predictFX)
export(predictFXSF)
export(projSparse)
export(updateMFAIR)
exportClasses(MFAIR)
exportClasses(MFAIRSingleFactor)
Expand Down
112 changes: 77 additions & 35 deletions R/mfairBackfitting.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#' Fit the MFAI model using backfitting algorithm.
#'
#' @import Matrix
#' @importFrom methods new
#' @importFrom rpart rpart.control
#'
Expand Down Expand Up @@ -27,23 +28,29 @@ fitBack <- function(object,
sf_para = list()) {
# Check K
if (object@K == 1) {
stop("The backfitting algorithm is equivalent to the greedy algorithm when rank K = 1!")
stop("The backfitting algorithm is equivalent to the greedy algorithm
when rank K = 1!")
} # End

# Check fitted functions F(), i.e., tree_0 and tree_lists
if (length(object@tree_0) == 0) {
object@tree_0 <- matrix(0, nrow = 1, ncol = object@K)
warning("The previous tree_0 (i.e., fitted functions) may not be saved!\n")
warning("The new tree_lists obtained after the backfitting algorithm may not accurately predict the new sample with auxiliary covariates.!\n")
warning("The new tree_lists obtained after the backfitting algorithm
may not accurately predict the new sample
with auxiliary covariates.!\n")
}
if (length(object@tree_lists) == 0) {
object@tree_lists <- lapply(1:object@K,
FUN = function(x) {
list()
}
)
warning("The previous tree_lists (i.e., fitted functions) may not be saved!\n")
warning("The new tree_lists obtained after the backfitting algorithm may not accurately predict the new sample with auxiliary covariates.!\n")
warning("The previous tree_lists (i.e., fitted functions)
may not be saved!\n")
warning("The new tree_lists obtained after the backfitting algorithm
may not accurately predict the new sample
with auxiliary covariates.!\n")
}

# Set up parameters for the gradient boosting part
Expand All @@ -67,7 +74,11 @@ fitBack <- function(object,

# Will be used for the partially observed matrix fitting
if (object@Y_missing) {
obs_indices <- !is.na(object@Y)
if (object@Y_sparse) {
obs_indices <- as.matrix(summary(Y)[, c(1, 2)])
} else {
obs_indices <- !is.na(object@Y)
}
}

tau <- object@tau
Expand All @@ -76,8 +87,7 @@ fitBack <- function(object,
# Begin backfitting algorithm
for (iter in 1:iter_max_bf) {
for (k in 1:object@K) {
# The residual (low-rank approximation using all factors but k-th)
R <- object@Y + object@Y_mean - predict(object, which_factors = -k)
# Initialize
mfair_sf <- new(
Class = "MFAIRSingleFactor",
Y_missing = object@Y_missing,
Expand All @@ -93,18 +103,20 @@ fitBack <- function(object,
tree_list = object@tree_lists[[k]]
)

if (object@Y_missing) {
# mfair_sf <- fitSFMissing(R, obs_indices, object@X, mfair_sf,
# object@learning_rate,
# tree_parameters = object@tree_parameters,
# ...
# )
if (object@Y_sparse) {
# The residual (low-rank approximation using all factors but k-th)
R <- object@Y -
projSparse(
predict(object, which_factors = -k, add_mean = FALSE),
obs_indices
)

mfair_sf <- do.call(
what = "fitSFMissing",
what = "fitSFSparse",
args = append(
list(
Y = R, obs_indices = obs_indices, X = object@X,
init = mfair_sf,
Y = R, X = object@X, init = mfair_sf,
obs_indices = obs_indices,
stage1 = FALSE,
learning_rate = object@learning_rate,
tree_parameters = object@tree_parameters
Expand All @@ -113,25 +125,51 @@ fitBack <- function(object,
)
)
} else {
# mfair_sf <- fitSFFully(R, object@X, mfair_sf,
# object@learning_rate,
# tree_parameters = object@tree_parameters,
# ...
# )
mfair_sf <- do.call(
what = "fitSFFully",
args = append(
list(
Y = R, X = object@X,
init = mfair_sf,
stage1 = FALSE,
learning_rate = object@learning_rate,
tree_parameters = object@tree_parameters
),
sf_para
# The residual (low-rank approximation using all factors but k-th)
R <- object@Y -
predict(object, which_factors = -k, add_mean = FALSE)

if (object@Y_missing) {
# mfair_sf <- fitSFMissing(R, obs_indices, object@X, mfair_sf,
# object@learning_rate,
# tree_parameters = object@tree_parameters,
# ...
# )
mfair_sf <- do.call(
what = "fitSFMissing",
args = append(
list(
Y = R, X = object@X, init = mfair_sf,
obs_indices = obs_indices,
stage1 = FALSE,
learning_rate = object@learning_rate,
tree_parameters = object@tree_parameters
),
sf_para
)
)
)
} else {
# mfair_sf <- fitSFFully(R, object@X, mfair_sf,
# object@learning_rate,
# tree_parameters = object@tree_parameters,
# ...
# )
mfair_sf <- do.call(
what = "fitSFFully",
args = append(
list(
Y = R, X = object@X,
init = mfair_sf,
stage1 = FALSE,
learning_rate = object@learning_rate,
tree_parameters = object@tree_parameters
),
sf_para
)
)
}
}

object <- updateMFAIR(object, mfair_sf, k)

if (verbose_bf_inner) {
Expand All @@ -145,9 +183,13 @@ fitBack <- function(object,
tau_new <- object@tau
beta_new <- object@beta

gap <- mean(abs(tau_new - tau) / abs(tau)) + mean(abs(beta_new - beta) / abs(beta))
gap <- mean(abs(tau_new - tau) / abs(tau)) +
mean(abs(beta_new - beta) / abs(beta))
if (verbose_bf_outer) {
cat("Iteration: ", iter, ", relative difference of model parameters: ", gap, ".\n", sep = "")
cat("Iteration: ", iter,
", relative difference of model parameters: ", gap, ".\n",
sep = ""
)
}
if (gap < tol_bf) {
break
Expand Down
37 changes: 25 additions & 12 deletions R/mfairELBO.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@
#'
#' @param Y Observed main data matrix.
#' @param object MFAIRSingleFactor object containing the information about the fitted single factor MFAI model.
#' @param obs_indices Indices of the observed entries in the main data matrix Y. Used only when Y is stored in the sparse mode; pass NULL otherwise.
#'
#' @return Numeric. The ELBO.
#' @export
#'
getELBO <- function(Y, object) {
getELBO <- function(Y, object, obs_indices) {
N <- nrow(Y)
M <- ncol(Y)

# n_missing <- sum(is.na(Y))
# if (n_missing >= 1) {
# Y_missing <- TRUE
# } else {
# Y_missing <- FALSE
# }

mu <- object@mu
mu_sq <- (object@mu)^2
nu <- object@nu
Expand All @@ -33,15 +27,34 @@ getELBO <- function(Y, object) {
if (object@Y_missing) {
n_obs <- object@n_obs
elbo1 <- -n_obs * log(2 * pi / tau) / 2
elbo2 <- -tau * sum((Y - as.matrix(mu) %*% t(nu))^2 + as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) - as.matrix(mu_sq) %*% t(nu_sq), na.rm = TRUE) / 2
elbo3 <- -N * log(2 * pi / beta) / 2 - beta * (sum(mu_sq) + sum(a_sq) - 2 * sum(mu * FX) + sum(FX^2)) / 2

if (!is.null(obs_indices)) { # Sparse mode
elbo2 <- -tau *
sum((Y - projSparse(as.matrix(mu) %*% t(nu), obs_indices))^2 +
projSparse(as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq), obs_indices) -
projSparse(as.matrix(mu_sq) %*% t(nu_sq), obs_indices))
} else {
elbo2 <- -tau * sum(
(Y - as.matrix(mu) %*% t(nu))^2 +
as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) -
as.matrix(mu_sq) %*% t(nu_sq),
na.rm = TRUE
) / 2
}
elbo3 <- -N * log(2 * pi / beta) / 2 -
beta * (sum(mu_sq) + sum(a_sq) -
2 * sum(mu * FX) + sum(FX^2)) / 2
elbo4 <- -M * log(2 * pi) / 2 - (sum(nu_sq) + sum(b_sq)) / 2
elbo5 <- sum(log(2 * pi * a_sq)) / 2 + N / 2
elbo6 <- sum(log(2 * pi * b_sq)) / 2 + M / 2
} else {
elbo1 <- -N * M * log(2 * pi / tau) / 2
elbo2 <- -tau * sum((Y - as.matrix(mu) %*% t(nu))^2 + as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) - as.matrix(mu_sq) %*% t(nu_sq)) / 2
elbo3 <- -N * log(2 * pi / beta) / 2 - beta * (sum(mu_sq) + N * a_sq - 2 * sum(mu * FX) + sum(FX^2)) / 2
elbo2 <- -tau * sum((Y - as.matrix(mu) %*% t(nu))^2 +
as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) -
as.matrix(mu_sq) %*% t(nu_sq)) / 2
elbo3 <- -N * log(2 * pi / beta) / 2 -
beta * (sum(mu_sq) + N * a_sq -
2 * sum(mu * FX) + sum(FX^2)) / 2
elbo4 <- -M * log(2 * pi) / 2 - (sum(nu_sq) + M * b_sq) / 2
elbo5 <- N * log(2 * pi * a_sq) / 2 + N / 2
elbo6 <- M * log(2 * pi * b_sq) / 2 + M / 2
Expand Down
38 changes: 29 additions & 9 deletions R/mfairGreedy.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fitGreedy <- function(object, K_max = NULL,
# Check whether partially observed main data matrix and record the indices
if (object@Y_missing) {
if (object@Y_sparse) {
obs_indices <- NULL # Sparse mode does not need indices
obs_indices <- as.matrix(summary(Y)[, c(1, 2)])
} else {
obs_indices <- !is.na(object@Y)
}
Expand All @@ -39,7 +39,8 @@ fitGreedy <- function(object, K_max = NULL,

# Check K_max
if (object@K_max > object@N || object@K_max > object@M) {
warning("The maximum rank allowed can not be larger than the rank of the main data matrix!\n")
warning("The maximum rank allowed can not
be larger than the rank of the main data matrix!\n")
object@K_max <- min(object@N, object@M)
warning("Reset K_max = ", object@K_max, "!\n")
}
Expand All @@ -58,9 +59,11 @@ fitGreedy <- function(object, K_max = NULL,
} else {
need_init <- rep(FALSE, object@K_max)
if (init_length < object@K_max) {
warning("Only the first ", init_length, " factors have been initialized, which is less than K_max!\n")
warning("Only the first ", init_length, " factors have been initialized,
which is less than K_max!\n")
need_init[-(1:init_length)] <- TRUE
warning("The remaining factors will be initialized automatically if needed!\n")
warning("The remaining factors will be initialized automatically
if needed!\n")
}
}

Expand Down Expand Up @@ -102,13 +105,26 @@ fitGreedy <- function(object, K_max = NULL,
}

# Fit the single factor MFAI model
if (object@Y_missing) {
if (object@Y_sparse) { # The main data matrix is partially observed and stored in the sparse mode
mfair_sf <- do.call(
what = "fitSFSparse",
args = append(
list(
Y = R, X = object@X, init = init,
obs_indices = obs_indices,
learning_rate = object@learning_rate,
tree_parameters = object@tree_parameters
),
sf_para
)
)
} else if (object@Y_missing) { # The main data matrix is partially observed but not stored in the sparse mode
mfair_sf <- do.call(
what = "fitSFMissing",
args = append(
list(
Y = R, obs_indices = obs_indices,
X = object@X, init = init,
Y = R, X = object@X, init = init,
obs_indices = obs_indices,
learning_rate = object@learning_rate,
tree_parameters = object@tree_parameters
),
Expand All @@ -120,7 +136,7 @@ fitGreedy <- function(object, K_max = NULL,
# tree_parameters = object@tree_parameters,
# ...
# )
} else {
} else { # The main data matrix is fully observed
mfair_sf <- do.call(
what = "fitSFFully",
args = append(
Expand Down Expand Up @@ -155,7 +171,11 @@ fitGreedy <- function(object, K_max = NULL,
}

# Prepare for the fitting for the next factor
R <- R - Y_k
if (object@Y_sparse) {
R <- R - projSparse(Y_k, obs_indices)
} else {
R <- R - Y_k
}

# Save the information about the fitted single factor MFAI model
object <- appendMFAIR(object, mfair_sf)
Expand Down
1 change: 0 additions & 1 deletion R/mfairImportance.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ getImportance <- function(object, which_factors = seq_len(object@K)) {
#' @param variables_names The names of the auxiliary covariates.
#'
#' @return Importance score vector. Each entry is the importance score of one auxiliary covariate.
#' @export
#'
getImportanceSF <- function(tree_list, variables_names) {
importance_list <- lapply(tree_list,
Expand Down
6 changes: 3 additions & 3 deletions R/mfairInitialization.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ createMFAIR <- function(Y, X,
message("The main data matrix Y has been stored in the sparse mode!")
} else if (Y_sparse == TRUE) { # Y is not in sparse mode, but we want it to be
obs_tf <- !is.na(Y) # Indicates whether observed or missing
obs_idx <- which(obs_tf, arr.ind = TRUE) # Indices of observed entries
obs_indices <- which(obs_tf, arr.ind = TRUE) # Indices of observed entries
Y <- Matrix::sparseMatrix(
i = obs_idx[, "row"],
j = obs_idx[, "col"],
i = obs_indices[, "row"], j = obs_indices[, "col"],
x = Y[obs_tf],
dims = c(N, M),
symmetric = FALSE, triangular = FALSE,
Expand Down Expand Up @@ -121,6 +120,7 @@ createMFAIR <- function(Y, X,

#' Initialize the parameters for the single factor MFAI model.
#'
#' @import Matrix
#' @importFrom stats rnorm var
#' @importFrom methods new
#'
Expand Down
Loading

0 comments on commit 9bd2361

Please sign in to comment.