diff --git a/R/mfairELBO.R b/R/mfairELBO.R index 4b72e4d..98ff4e5 100644 --- a/R/mfairELBO.R +++ b/R/mfairELBO.R @@ -30,14 +30,20 @@ getELBO <- function(Y, object, obs_indices = NULL) { if (!is.null(obs_indices)) { # Sparse mode elbo2 <- -tau * - sum((Y - projSparse(as.matrix(mu) %*% t(nu), obs_indices))^2 + - projSparse(as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq), obs_indices) - - projSparse(as.matrix(mu_sq) %*% t(nu_sq), obs_indices)) + sum((Y - projSparse( + tcrossprod(as.matrix(mu), as.matrix(nu)), obs_indices + ))^2 + + projSparse( + tcrossprod(as.matrix(mu_sq + a_sq), as.matrix(nu_sq + b_sq)), obs_indices + ) - + projSparse( + tcrossprod(as.matrix(mu_sq), as.matrix(nu_sq)), obs_indices + )) } else { elbo2 <- -tau * sum( - (Y - as.matrix(mu) %*% t(nu))^2 + - as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) - - as.matrix(mu_sq) %*% t(nu_sq), + (Y - tcrossprod(as.matrix(mu), as.matrix(nu)))^2 + + tcrossprod(as.matrix(mu_sq + a_sq), as.matrix(nu_sq + b_sq)) - + tcrossprod(as.matrix(mu_sq), as.matrix(nu_sq)), na.rm = TRUE ) / 2 } @@ -49,9 +55,9 @@ getELBO <- function(Y, object, obs_indices = NULL) { elbo6 <- sum(log(2 * pi * b_sq)) / 2 + M / 2 } else { elbo1 <- -N * M * log(2 * pi / tau) / 2 - elbo2 <- -tau * sum((Y - as.matrix(mu) %*% t(nu))^2 + - as.matrix(mu_sq + a_sq) %*% t(nu_sq + b_sq) - - as.matrix(mu_sq) %*% t(nu_sq)) / 2 + elbo2 <- -tau * sum((Y - tcrossprod(as.matrix(mu), as.matrix(nu)))^2 + + tcrossprod(as.matrix(mu_sq + a_sq), as.matrix(nu_sq + b_sq)) - + tcrossprod(as.matrix(mu_sq), as.matrix(nu_sq))) / 2 elbo3 <- -N * log(2 * pi / beta) / 2 - beta * (sum(mu_sq) + N * a_sq - 2 * sum(mu * FX) + sum(FX^2)) / 2