Skip to content

Commit

Permalink
Use crossprod() and tcrossprod() to replace t(x) %*% y and x %*% t(y)…
Browse files Browse the repository at this point in the history
… respectively
  • Loading branch information
statwangz committed Jan 30, 2024
1 parent bf89d38 commit ee7702e
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions R/mfairELBO.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,20 @@ getELBO <- function(Y, object, obs_indices = NULL) {

if (!is.null(obs_indices)) { # Sparse mode
elbo2 <- -tau *
sum((Y - projSparse(as.matrix(mu) %*% t(nu), obs_indices))^2 +
projSparse(as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq), obs_indices) -
projSparse(as.matrix(mu_sq) %*% t(nu_sq), obs_indices))
sum((Y - projSparse(
tcrossprod(as.matrix(mu), as.matrix(nu)), obs_indices
))^2 +
projSparse(
tcrossprod(as.matrix(mu_sq + a_sq), as.matrix(nu_sq + b_sq)), obs_indices
) -
projSparse(
tcrossprod(as.matrix(mu_sq), as.matrix(nu_sq)), obs_indices
))
} else {
elbo2 <- -tau * sum(
(Y - as.matrix(mu) %*% t(nu))^2 +
as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) -
as.matrix(mu_sq) %*% t(nu_sq),
(Y - tcrossprod(as.matrix(mu), as.matrix(nu)))^2 +
tcrossprod(as.matrix(mu_sq + a_sq), as.matrix(nu_sq + b_sq)) -
tcrossprod(as.matrix(mu_sq), as.matrix(nu_sq)),
na.rm = TRUE
) / 2
}
Expand All @@ -49,9 +55,9 @@ getELBO <- function(Y, object, obs_indices = NULL) {
elbo6 <- sum(log(2 * pi * b_sq)) / 2 + M / 2
} else {
elbo1 <- -N * M * log(2 * pi / tau) / 2
elbo2 <- -tau * sum((Y - as.matrix(mu) %*% t(nu))^2 +
as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) -
as.matrix(mu_sq) %*% t(nu_sq)) / 2
elbo2 <- -tau * sum((Y - tcrossprod(as.matrix(mu), as.matrix(nu)))^2 +
tcrossprod(as.matrix(mu_sq + a_sq), as.matrix(nu_sq + b_sq)) -
tcrossprod(as.matrix(mu_sq), as.matrix(nu_sq))) / 2
elbo3 <- -N * log(2 * pi / beta) / 2 -
beta * (sum(mu_sq) + N * a_sq -
2 * sum(mu * FX) + sum(FX^2)) / 2
Expand Down

0 comments on commit ee7702e

Please sign in to comment.