Skip to content

Commit

Permalink
Merge pull request #55 from brianjohnhaas/negbinom
Browse files Browse the repository at this point in the history
Use negative binomial for simulating spike-in

Former-commit-id: ef281f5
Former-commit-id: 4f948b5
  • Loading branch information
GeorgescuC authored Nov 7, 2018
2 parents df14e81 + 713ae63 commit a2059fd
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 53 deletions.
2 changes: 1 addition & 1 deletion R/inferCNV_heatmap.R
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ plot_cnv <- function(infercnv_obj,
observation_file_base,
sep=" "))
row.names(obs_data) <- orig_row_names
write.table(obs_data[data_observations$rowInd,data_observations$colInd],
write.table(t(obs_data[data_observations$rowInd,data_observations$colInd]),
file=observation_file_base)
}
}
Expand Down
55 changes: 33 additions & 22 deletions R/inferCNV_ops.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
#'
#' @param include.spike If true, introduces an artificial spike-in of data at ~0x and 2x for scaling residuals between 0-2. (default: F)
#'
#' @param spike_in_chrs vector listing of chr names to use for modeling spike-ins (default: NULL - uses the two largest chrs. ex. c('chr1', 'chr2') )
#'
#' @param spike_in_multiplier vector of weights matching spike_in_chrs (default: c(0.01, 2.0) for modeling loss/gain of both chrs)
#'
#' @param pseudocount Number of counts to add to each gene of each cell post-filtering of genes and cells and pre-total sum count normalization. (default: 0)
#'
#' @param debug If true, output debug level logging.
Expand Down Expand Up @@ -107,7 +111,7 @@ run <- function(infercnv_obj,
use_zscores=FALSE,
remove_genes_at_chr_ends=FALSE,

mask_nonDE_genes=TRUE,
mask_nonDE_genes=FALSE,
mask_nonDE_pval=0.05,
test.use='wilcoxon',

Expand All @@ -116,7 +120,11 @@ run <- function(infercnv_obj,
debug=FALSE, #for debug level logging

include.spike = FALSE,


# must specify both below if to be used, and must match in vec length
spike_in_chrs = NULL, # use defaults
spike_in_multiplier_vec = NULL, # use defaults

pseudocount = 0

) {
Expand Down Expand Up @@ -202,10 +210,14 @@ run <- function(infercnv_obj,
if (include.spike) {
step_count = step_count + 1
flog.info(sprintf("\n\n\tSTEP %02d: Spiking in genes with variation added for tracking\n", step_count))

if (! (is.null(spike_in_chrs) && is.null(spike_in_multiplier_vec)) ) {
infercnv_obj <- spike_in_variation_chrs(infercnv_obj, spike_in_chrs, spike_in_multiplier_vec)
} else {
infercnv_obj <- spike_in_variation_chrs(infercnv_obj)
}

infercnv_obj <- spike_in_variation_chrs(infercnv_obj)

# Plot incremental steps.
# Plot incremental steps.
if (plot_steps){

infercnv_obj_spiked <- infercnv_obj
Expand Down Expand Up @@ -657,9 +669,6 @@ run <- function(infercnv_obj,
output_filename=sprintf("infercnv.%02d_scaled_by_spike", step_count))
}

# remove the spike now
infercnv_obj <- remove_spike(infercnv_obj)

}


Expand Down Expand Up @@ -697,6 +706,12 @@ run <- function(infercnv_obj,
}
}

if (include.spike) {
# remove the spike before making the final plot.
infercnv_obj <- remove_spike(infercnv_obj)
}


save('infercnv_obj', file=file.path(out_dir, "run.final.infercnv_obj"))

flog.info("Making the final infercnv heatmap")
Expand Down Expand Up @@ -1143,7 +1158,7 @@ center_cell_expr_across_chromosome <- function(infercnv_obj, method="mean") { #

#' @title require_above_min_mean_expr_cutoff ()
#'
#' @description Filters out genes that have fewer than the corresponding mean value across the reference cell values.
#' @description Filters out genes that have fewer than the corresponding mean value across all cell values.
#'
#' @param infercnv_obj infercnv_object
#'
Expand All @@ -1158,10 +1173,8 @@ require_above_min_mean_expr_cutoff <- function(infercnv_obj, min_mean_expr_cutof

flog.info(paste("::above_min_mean_expr_cutoff:Start", sep=""))

# restrict to reference cells:
ref_cells_data <- infercnv_obj@expr.data[ , get_reference_grouped_cell_indices(infercnv_obj) ]

indices <-.below_min_mean_expr_cutoff(ref_cells_data, min_mean_expr_cutoff)
indices <-.below_min_mean_expr_cutoff(infercnv_obj@expr.data, min_mean_expr_cutoff)
if (length(indices) > 0) {
flog.info(sprintf("Removing %d genes from matrix as below mean expr threshold: %g",
length(indices), min_mean_expr_cutoff))
Expand Down Expand Up @@ -1195,7 +1208,7 @@ require_above_min_mean_expr_cutoff <- function(infercnv_obj, min_mean_expr_cutof

#' @title require_above_min_cells_ref()
#'
#' @description Filters out genes that have fewer than specified number of reference cells expressing them.
#' @description Filters out genes that have fewer than specified number of cells expressing them.
#'
#' @param infercnv_obj infercnv_object
#'
Expand All @@ -1207,15 +1220,11 @@ require_above_min_mean_expr_cutoff <- function(infercnv_obj, min_mean_expr_cutof
#'

require_above_min_cells_ref <- function(infercnv_obj, min_cells_per_gene) {

ref_cell_indices = get_reference_grouped_cell_indices(infercnv_obj)

ref_data = infercnv_obj@expr.data[,ref_cell_indices]

ref_genes_passed = which(apply(ref_data, 1, function(x) { sum(x>0 & ! is.na(x)) >= min_cells_per_gene}))
genes_passed = which(apply(infercnv_obj@expr.data, 1, function(x) { sum(x>0 & ! is.na(x)) >= min_cells_per_gene}))

num_genes_total = dim(ref_data)[1]
num_removed = num_genes_total - length(ref_genes_passed)
num_genes_total = dim(infercnv_obj@expr.data)[1]
num_removed = num_genes_total - length(genes_passed)
if (num_removed > 0) {

flog.info(sprintf("Removed %d genes having fewer than %d min cells per gene = %g %% genes removed here",
Expand All @@ -1229,7 +1238,7 @@ require_above_min_cells_ref <- function(infercnv_obj, min_cells_per_gene) {
}


infercnv_obj <- remove_genes(infercnv_obj, -1 * ref_genes_passed)
infercnv_obj <- remove_genes(infercnv_obj, -1 * genes_passed)


}
Expand Down Expand Up @@ -1904,7 +1913,9 @@ anscombe_transform <- function(infercnv_obj) {

}


#' @keywords internal
#' @noRd
#'
add_pseudocount <- function(infercnv_obj, pseudocount) {

flog.info(sprintf("Adding pseudocount: %g", pseudocount))
Expand Down
184 changes: 156 additions & 28 deletions R/inferCNV_spike.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,22 @@ spike_in_variation_chrs <- function(infercnv_obj,
normal_cells_idx = infercnv::get_reference_grouped_cell_indices(infercnv_obj)
normal_cells_expr = infercnv_obj@expr.data[,normal_cells_idx]

# zeros are a problem here...
gene_means = rowMeans(normal_cells_expr)

mean_p0_table = .get_mean_vs_p0_table(infercnv_obj)

## apply spike-in multiplier vec
for (i in 1:length(spike_in_multiplier_vec)) {

gene_indices = gene_selection_listing[[i]]
multiplier = spike_in_multiplier_vec[i]

normal_cells_expr[gene_indices, ] = normal_cells_expr[gene_indices, ] * multiplier

}

gene_means[gene_indices] = gene_means[gene_indices] * multiplier
}

## get simulated matrix
sim_matrix = .get_simulated_cell_matrix(mvtable, normal_cells_expr, max_cells)
sim_matrix = .get_simulated_cell_matrix(gene_means, mean_p0_table, max_cells)

## integrate into expr data and count data matrices
ncol_begin = ncol(infercnv_obj@expr.data) + 1
Expand All @@ -114,11 +118,9 @@ spike_in_variation_chrs <- function(infercnv_obj,
##' the mean/variance relationship for all genes in all cell groupings.
##'
##' Cells are simulated as so:
##' A random cell is selected from the normal cell expression matrix.
##' The expression of each gene is treated as a targeted mean expression value.
##' The variance is chosen based on a spline fit to the mean/variance relationship provided.
##' A random expression value is generated from a normal distribution with corresponding (mean, variance)
##'
##' The mean for genes in the normal cells are computed
##' A random expression value is chosen for each gene using a negative binomial distribution with dispersion = 0.1
##'
##' Genes are named according to the input expression matrix, and cells are named 'spike_{number}'.
##'
##' @param mean_var_table : a data.frame containing three columns: group_name, mean, variance of expression per gene per grouping.
Expand All @@ -133,38 +135,58 @@ spike_in_variation_chrs <- function(infercnv_obj,
##' @noRd
##'

.get_simulated_cell_matrix <- function(mean_var_table, normal_cell_expr, num_cells) {
.get_simulated_cell_matrix <- function(gene_means, mean_p0_table, num_cells) {

# should be working on the total sum count normalized data.
# model the mean variance relationship

ngenes = length(gene_means)

dropout_logistic_params <- .get_logistic_params(mean_p0_table)

s = smooth.spline(log2(mean_var_table$m+1), log2(mean_var_table$v+1))

spike_cell_names = paste0("spike_", 1:num_cells)

ngenes = nrow(normal_cell_expr)
spike_cell_names = paste0('spike_cell_', 1:num_cells)

sim_expr_val <- function(gene_idx, rand_cell_idx) {
m = normal_cell_expr[gene_idx, rand_cell_idx]
v = predict(s, log2(m+1))$y
v = max(0, 2^v-1)
val = max(0, rnorm(n=1, mean=m, sd=sqrt(v)))
return(val)
}

sim_cell_matrix = matrix(rep(0,ngenes*num_cells), nrow=ngenes)
rownames(sim_cell_matrix) = rownames(normal_cell_expr)
rownames(sim_cell_matrix) = names(gene_means)
colnames(sim_cell_matrix) = spike_cell_names

sim_expr_vals <- function(gene_idx) {
m = gene_means[gene_idx]
return(.sim_expr_val(m, dropout_logistic_params))
}

for (i in 1:num_cells) {
rand_cell_idx = floor(runif(1) * ncol(normal_cell_expr)+1)
newvals = sapply(1:ngenes, FUN=sim_expr_val, rand_cell_idx)
newvals = sapply(1:ngenes, FUN=sim_expr_vals)
sim_cell_matrix[,i] = newvals
}

return(sim_cell_matrix)
}

##' @keywords internal
##' @noRd
##'

.sim_expr_val <- function(m, dropout_logistic_params) {

# include drop-out prediction

val = 0
if (m > 0) {
dropout_prob <- .logistic(x=log(m), midpt=dropout_logistic_params$midpt, slope=dropout_logistic_params$slope)

if (runif(1) > dropout_prob) {
# not a drop-out
val = rnbinom(n=1, mu=m, size=1/0.1) #fixed dispersion at 0.1
}
}
return(val)
}




##' .get_mean_var_table()
##'
##' Computes the gene mean/variance table based on all defined cell groupings (reference and observations)
Expand Down Expand Up @@ -199,7 +221,7 @@ spike_in_variation_chrs <- function(infercnv_obj,
return(mean_var_table)
}

##' get_spike_in_average_bounds()
##' .get_spike_in_average_bounds()
##'
##' return mean bounds for expression of all cells in the spike-in
##'
Expand Down Expand Up @@ -292,7 +314,11 @@ scale_cnv_by_spike <- function(infercnv_obj) {
}


# selects the specified number of chrs having the largest number of (expressed) genes
#' selects the specified number of chrs having the largest number of (expressed) genes
#' @keywords internal
#' @noRd
#'

.select_longest_chrs <- function(infercnv_obj, num_chrs_want) {

# get count of chrs
Expand All @@ -301,3 +327,105 @@ scale_cnv_by_spike <- function(infercnv_obj) {
return(counts$chr[1:num_chrs_want])

}

#' Computes probability of seeing a zero expr val as a function of the mean gene expression
#' The p(0 | mean_expr) is computed separately for each sample grouping.
#'
#' @keywords internal
#' @noRd
#'

.get_mean_vs_p0_table <- function(infercnv_obj) {

group_indices = c(infercnv_obj@observation_grouped_cell_indices, infercnv_obj@reference_grouped_cell_indices)

mean_p0_table = NULL

for (group_name in names(group_indices)) {
flog.info(sprintf("processing group: %s", group_name))
expr.data = infercnv_obj@expr.data[, group_indices[[ group_name ]] ]

group_mean_p0_table <- .get_mean_vs_p0_from_matrix(expr.data)
group_mean_p0_table[[ 'group_name' ]] <- group_name

if (is.null(mean_p0_table)) {
mean_p0_table = group_mean_p0_table
} else {
mean_p0_table = rbind(mean_p0_table, group_mean_p0_table)
}
}

return(mean_p0_table)
}

#' Computes probability of seeing a zero expr val as a function of the mean gene expression
#' based on the input expression matrix.
#'
#' @keywords internal
#' @noRd
#'

.get_mean_vs_p0_from_matrix <- function(expr.data) {
ncells = ncol(expr.data)
m = rowMeans(expr.data)
numZeros = apply(expr.data, 1, function(x) { sum(x==0) })

pZero = numZeros/ncells

mean_p0_table = data.frame(m=m, p0=pZero)

return(mean_p0_table)
}


#'
#' Logistic function
#'
#' InferCNV note: Standard function here, but lifted from
#' Splatter (Zappia, Phipson, and Oshlack, 2017)
#' https://genomebiology.biomedcentral.com/articles/10.1186/s13059-017-1305-0
#'
#' Implementation of the logistic function
#'
#' @param x value to apply the function to.
#' @param x0 midpoint parameter. Gives the centre of the function.
#' @param k shape parameter. Gives the slope of the function.
#'
#' @return Value of logistic function with given parameters
#'
#' @keywords internal
#' @noRd
#'
.logistic <- function(x, midpt, slope) {
1 / (1 + exp(-slope * (x - midpt)))
}



#' Given the mean, p0 table, fits the data to a logistic function to compute
#' the shape of the logistic distribution.
#'
#' @keywords internal
#' @noRd
#'

.get_logistic_params <- function(mean_p0_table) {

mean_p0_table <- mean_p0_table[mean_p0_table$m > 0, ] # remove zeros, can't take log.

x = log(mean_p0_table$m)
y = mean_p0_table$p0

df = data.frame(x,y)

#write.table(df, "_logistic_params", quote=F, sep="\t") # debugging...

fit <- nls(y ~ .logistic(x, midpt = x0, slope = k), data = df, start = list(x0 = mean(x), k = -1)) # borrowed/updated from splatter

logistic_params = list()

logistic_params[[ 'midpt' ]] <- summary(fit)$coefficients["x0", "Estimate"]
logistic_params[[ 'slope' ]] <- summary(fit)$coefficients["k", "Estimate"]

return(logistic_params)
}
Loading

0 comments on commit a2059fd

Please sign in to comment.