diff --git a/DESCRIPTION b/DESCRIPTION index 5b98635..33c8141 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,8 @@ Suggests: pheatmap, reshape2, rmarkdown, - scales + scales, + softImpute VignetteBuilder: knitr URL: https://yanglabhkust.github.io/mfair/ BugReports: https://github.com/YangLabHKUST/mfair/issues diff --git a/NAMESPACE b/NAMESPACE index 86c4a8a..a930419 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,11 +8,14 @@ export(getELBO) export(getImportance) export(initSF) export(predictFX) +export(softImputeCV) +export(softImputeCV_sparse) export(updateMFAIR) exportClasses(MFAIR) exportClasses(MFAIRSingleFactor) exportMethods(predict) import(Matrix) +import(softImpute) importFrom(dplyr,left_join) importFrom(methods,new) importFrom(rpart,rpart) diff --git a/R/softImputeCV.R b/R/softImputeCV.R new file mode 100644 index 0000000..3d8b576 --- /dev/null +++ b/R/softImputeCV.R @@ -0,0 +1,193 @@ +#' Cross-validation for softImpute +#' +#' @import softImpute +#' +#' @param Y A matrix. The main data matrix of N samples and M features. +#' @param rank_max An integer. The maximum rank allowed. +#' @param lambda_range A vector containing the minimal and maximal value of the parameter lambda. +#' @param nfold An integer. The total number of validation sets. +#' @param para_length An integer. The total number of parameter lambda. +#' +#' @return A list containing the cross-validation results. 
#' @details
#' The observed (non-`NA`) entries of `Y` are randomly assigned to `nfold`
#' groups. For every fold the entries of that group are masked, `softImpute`
#' is fitted on the remaining entries along a log-spaced grid of
#' `para_length` lambda values (visited from the last grid point to the first,
#' each fit warm-starting the next), and the RMSE on the masked entries is
#' recorded.
#'
#' @export
#'
softImputeCV <- function(Y, rank_max = NULL, lambda_range = NULL,
                         nfold = 10, para_length = 100) {
  N <- dim(Y)[1]
  M <- dim(Y)[2]
  # Report settings
  message("Info: Matrix dimension: ", N, " * ", M)
  message("Info: Number of cv folds: ", nfold)

  # Linear indices of the observed entries; only these are split into folds
  idx_all <- which(!is.na(Y))
  n <- length(idx_all)

  # Fail early with a readable message instead of producing NaN errors (empty
  # folds) or a division by zero in `cv_sd` (single fold) further down.
  if (!is.numeric(nfold) || length(nfold) != 1L || is.na(nfold) ||
    nfold < 2 || nfold != round(nfold)) {
    stop("`nfold` must be a single whole number >= 2.", call. = FALSE)
  }
  if (nfold > n) {
    stop("`nfold` (", nfold, ") cannot exceed the number of observed ",
      "entries (", n, ").",
      call. = FALSE
    )
  }
  if (!is.numeric(para_length) || length(para_length) != 1L ||
    is.na(para_length) || para_length < 1 ||
    para_length != round(para_length)) {
    stop("`para_length` must be a single whole number >= 1.", call. = FALSE)
  }

  if (is.null(rank_max)) {
    rank_max <- min(N, M) - 1
  }

  if (is.null(lambda_range)) {
    # Derive the grid limits from the singular values (`$d`) of a fit on the
    # full data that uses softImpute's default lambda.
    hi_fit <- softImpute(Y, rank.max = rank_max)
    lambda_min <- utils::tail(hi_fit$d, 1) / 100
    # head() rather than `[1:2]`: with a single singular value `[1:2]` pads
    # with NA and the mean becomes NA.
    lambda_max <- mean(utils::head(hi_fit$d, 2))
  } else {
    if (!is.numeric(lambda_range) || length(lambda_range) < 2L) {
      stop("`lambda_range` must be a numeric vector of length 2.",
        call. = FALSE
      )
    }
    lambda_min <- lambda_range[1]
    lambda_max <- lambda_range[2]
  }
  # The grid is built on the log scale, so both limits must be positive
  lambda_limits <- c(lambda_min, lambda_max)
  if (length(lambda_limits) != 2L || !all(is.finite(lambda_limits)) ||
    any(lambda_limits <= 0)) {
    stop("The lambda limits must be two finite, strictly positive numbers.",
      call. = FALSE
    )
  }
  lambda_list <- exp(seq(log(lambda_min),
    log(lambda_max),
    length.out = para_length
  ))

  # Test RMSE, one row per fold and one column per lambda
  test_err <- matrix(0, nrow = nfold, ncol = para_length)
  rownames(test_err) <- paste0("fold_", seq_len(nfold))
  colnames(test_err) <- paste0("lambda_", seq_len(para_length))

  # Random fold label (1..nfold) for every observed entry
  idx_group <- ceiling(sample.int(n) / n * nfold)

  cat("Start cross validation process...... total", nfold, "validation sets \n")

  for (i in seq_len(nfold)) {
    cat("Validation set", i, "... \n")

    # Mask the held-out entries in the training copy
    test_pos <- idx_all[idx_group == i]
    Y_train <- Y
    Y_train[test_pos] <- NA
    y_test <- Y[test_pos]

    # Walk the grid from its last point to its first; `warm.start = NULL`
    # (a cold start) is used for the first fit of every fold.
    si_fit <- NULL
    for (j in rev(seq_len(para_length))) {
      si_fit <- softImpute(Y_train,
        rank.max = rank_max,
        lambda = lambda_list[j], warm.start = si_fit
      )

      Y_hat <- complete(Y_train, si_fit)
      test_err[i, j] <- sqrt(mean((y_test - Y_hat[test_pos])^2, na.rm = TRUE))
    }
  }

  # NOTE(review): this is sqrt(var / (nfold - 1)); the usual standard error of
  # the fold mean would be sqrt(var / nfold). Kept unchanged for compatibility.
  cv_sd <- sqrt(apply(test_err, MARGIN = 2, stats::var) / (nfold - 1))
  cv_mean <- colMeans(test_err)
  idx_min <- which.min(cv_mean)
  lambda_best <- lambda_list[idx_min]

  si_cv_results <- list(
    cv_mean = cv_mean, cv_sd = cv_sd,
    lambda_list = lambda_list, lambda_best = lambda_best,
    test_err = test_err
  )

  si_cv_results
}

#' Sparse version of cross-validation for softImpute
#'
#' @import Matrix
#' @import softImpute
#'
#' @param Y A Matrix::dgCMatrix. The main data matrix of N samples and M features.
#' @param rank_max An integer. The maximum rank allowed.
#' @param lambda_range A vector containing the minimal and maximal value of the parameter lambda.
#' @param nfold An integer. The total number of validation sets.
#' @param para_length An integer. The total number of parameter lambda.
#'
#' @return A list containing the cross-validation results.
#' @details
#' The stored entries of `Y` are treated as the observed values and are
#' randomly assigned to `nfold` groups. For every fold the training entries are
#' centred by their mean, `softImpute` is fitted along a log-spaced grid of
#' `para_length` lambda values (visited from the last grid point to the first,
#' each fit warm-starting the next), and the RMSE on the held-out entries is
#' recorded.
#'
#' @export
#'
softImputeCV_sparse <- function(Y, rank_max = NULL, lambda_range = NULL,
                                nfold = 10, para_length = 100) {
  N <- dim(Y)[1]
  M <- dim(Y)[2]
  # Report settings
  message("Info: Matrix dimension: ", N, " * ", M)
  message("Info: Number of cv folds: ", nfold)

  # (row, column) positions of the stored entries; the code relies on these
  # rows being in the same order as the values in `Y@x`.
  idx_all <- as.matrix(summary(Y)[, c(1, 2)])
  n <- nrow(idx_all)
  obs <- Y@x

  # Fail early with a readable message instead of producing NaN errors (empty
  # folds) or a division by zero in `cv_sd` (single fold) further down.
  if (!is.numeric(nfold) || length(nfold) != 1L || is.na(nfold) ||
    nfold < 2 || nfold != round(nfold)) {
    stop("`nfold` must be a single whole number >= 2.", call. = FALSE)
  }
  if (nfold > n) {
    stop("`nfold` (", nfold, ") cannot exceed the number of observed ",
      "entries (", n, ").",
      call. = FALSE
    )
  }
  if (!is.numeric(para_length) || length(para_length) != 1L ||
    is.na(para_length) || para_length < 1 ||
    para_length != round(para_length)) {
    stop("`para_length` must be a single whole number >= 1.", call. = FALSE)
  }

  if (is.null(rank_max)) {
    rank_max <- min(N, M) - 1
  }

  if (is.null(lambda_range)) {
    # Derive the grid limits from the singular values (`$d`) of a fit on the
    # full data that uses softImpute's default lambda.
    hi_fit <- softImpute(Y, rank.max = rank_max)
    lambda_min <- utils::tail(hi_fit$d, 1) / 100
    # head() rather than `[1:2]`: with a single singular value `[1:2]` pads
    # with NA and the mean becomes NA.
    lambda_max <- mean(utils::head(hi_fit$d, 2))
  } else {
    if (!is.numeric(lambda_range) || length(lambda_range) < 2L) {
      stop("`lambda_range` must be a numeric vector of length 2.",
        call. = FALSE
      )
    }
    lambda_min <- lambda_range[1]
    lambda_max <- lambda_range[2]
  }
  # The grid is built on the log scale, so both limits must be positive
  lambda_limits <- c(lambda_min, lambda_max)
  if (length(lambda_limits) != 2L || !all(is.finite(lambda_limits)) ||
    any(lambda_limits <= 0)) {
    stop("The lambda limits must be two finite, strictly positive numbers.",
      call. = FALSE
    )
  }
  lambda_list <- exp(seq(log(lambda_min),
    log(lambda_max),
    length.out = para_length
  ))

  # Test RMSE, one row per fold and one column per lambda
  test_err <- matrix(0, nrow = nfold, ncol = para_length)
  rownames(test_err) <- paste0("fold_", seq_len(nfold))
  colnames(test_err) <- paste0("lambda_", seq_len(para_length))

  # Random fold label (1..nfold) for every stored entry
  idx_group <- ceiling(sample.int(n) / n * nfold)

  cat("Start cross validation process...... total", nfold, "validation sets \n")

  for (i in seq_len(nfold)) {
    cat("Validation set", i, "... \n")

    # Explicit positive indices for both parts; negative indexing with an
    # empty `which()` result would silently select nothing.
    in_test <- idx_group == i
    train_idx <- which(!in_test)
    test_idx <- which(in_test)
    train_mean <- mean(obs[train_idx]) # Mean value of the training set

    # Build the training matrix with the full N x M shape. Letting the
    # dimensions be inferred from the largest training index yields a smaller
    # matrix whenever the last row or column is observed only in the test
    # fold, and the test indices passed to impute() are then out of range.
    Y_train <- methods::new(
      "Incomplete",
      Matrix::sparseMatrix(
        i = idx_all[train_idx, 1],
        j = idx_all[train_idx, 2],
        x = obs[train_idx] - train_mean,
        dims = c(N, M)
      )
    )

    # Held-out positions and values do not depend on lambda
    test_i <- idx_all[test_idx, 1]
    test_j <- idx_all[test_idx, 2]
    y_test <- obs[test_idx]

    # Walk the grid from its last point to its first; `warm.start = NULL`
    # (a cold start) is used for the first fit of every fold.
    si_fit <- NULL
    for (j in rev(seq_len(para_length))) {
      si_fit <- softImpute(Y_train,
        rank.max = rank_max,
        lambda = lambda_list[j], warm.start = si_fit
      )

      # The prediction corresponding to the test set, back on the original
      # (uncentred) scale
      Y_hat <- impute(si_fit, i = test_i, j = test_j) + train_mean
      test_err[i, j] <- sqrt(mean((y_test - Y_hat)^2))
    }
  }

  # NOTE(review): this is sqrt(var / (nfold - 1)); the usual standard error of
  # the fold mean would be sqrt(var / nfold). Kept unchanged for compatibility.
  cv_sd <- sqrt(apply(test_err, MARGIN = 2, stats::var) / (nfold - 1))
  cv_mean <- colMeans(test_err)
  idx_min <- which.min(cv_mean)
  lambda_best <- lambda_list[idx_min]

  si_cv_results <- list(
    cv_mean = cv_mean, cv_sd = cv_sd,
    lambda_list = lambda_list, lambda_best = lambda_best,
    test_err = test_err
  )

  si_cv_results
}

# ---- Patch text that shared these source lines (start of the generated
# ---- man/softImputeCV.Rd hunk), kept verbatim as comments; it continues on
# ---- the following source line.
# diff --git a/man/softImputeCV.Rd b/man/softImputeCV.Rd
# new file mode 100644
# index 0000000..733eb0c
# --- /dev/null
# +++ b/man/softImputeCV.Rd
# @@ -0,0 +1,31 @@
# +% Generated by roxygen2: do not edit by hand
# +% Please edit documentation in R/softImputeCV.R
# +\name{softImputeCV}
# +\alias{softImputeCV}
# +\title{Cross-validation for softImpute}
# +\usage{
# +softImputeCV(
# +  Y,
# +  rank_max = NULL,
# +  lambda_range = NULL,
# +  nfold = 10,
# +  para_length = 100
# +)
# +}
# +\arguments{
# +\item{Y}{A matrix. The main data matrix of N samples and M features.}
# +
# +\item{rank_max}{An integer. The maximum rank allowed.}
# +
# +\item{lambda_range}{A vector containing the minimal and maximal value of the parameter lambda.}
# +
# +\item{nfold}{An integer.
The total number of validation sets.} + +\item{para_length}{An integer. The total number of parameter lambda.} +} +\value{ +A list containing the cross-validation results. +} +\description{ +Cross-validation for softImpute +} diff --git a/man/softImputeCV_sparse.Rd b/man/softImputeCV_sparse.Rd new file mode 100644 index 0000000..e47aa99 --- /dev/null +++ b/man/softImputeCV_sparse.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/softImputeCV.R +\name{softImputeCV_sparse} +\alias{softImputeCV_sparse} +\title{Sparse version of cross-validation for softImpute} +\usage{ +softImputeCV_sparse( + Y, + rank_max = NULL, + lambda_range = NULL, + nfold = 10, + para_length = 100 +) +} +\arguments{ +\item{Y}{A Matrix::dgCMatrix. The main data matrix of N samples and M features.} + +\item{rank_max}{An integer. The maximum rank allowed.} + +\item{lambda_range}{A vector containing the minimal and maximal value of the parameter lambda.} + +\item{nfold}{An integer. The total number of validation sets.} + +\item{para_length}{An integer. The total number of parameter lambda.} +} +\value{ +A list containing the cross-validation results. +} +\description{ +Sparse version of cross-validation for softImpute +}