# Provenance: added to STitch3D/model.py, directly after the `eval` method, by
# commit 803c056aed67e89e4c041750d38099a6b2bc2669
# ("Update single-cell spatial charting function", WANG Gefei, 2023-06-22).
def cells_to_spatial(self, adata_ref_input,
                     celltype_ref_col="celltype",  # column of adata_ref_input.obs for cell type information
                     celltype_ref=None,  # specify cell types to use for deconvolution
                     target_num=20,  # target number of cells per spot
                     save=False,
                     lam_sim=0.1,
                     lam_num=1e-3,
                     lam_M=1,
                     lr=2e-3,
                     training_steps_M=20000,
                     report_loss=True,
                     step_interval=2000,
                     output_path="./results"):
    """Map reference single cells onto spatial spots.

    A cell-by-spot score matrix ``M`` is optimised so that its row-wise
    softmax ``M_hat`` (N_cells x N_spots) reproduces, per spot, the
    cell-type proportions ``self.beta`` and the adjusted expression
    ``(self.beta @ self.basis) * self.lY``. Every cell is then assigned to
    the spot with the largest ``M_hat`` entry.

    Parameters
    ----------
    adata_ref_input
        Reference single-cell dataset. It is copied, never modified. It has
        to be processed the same way as when the cell-type matrix was
        constructed, and must contain every gene in ``self.adata_st.var``.
    celltype_ref_col
        Column of ``adata_ref_input.obs`` holding the cell-type labels.
    celltype_ref
        List of cell types to keep. If ``None``, every cell type with more
        than one cell is kept. Labels are compared as strings.
    target_num
        Currently ignored: the per-spot target used by the cell-number
        regulariser is always ``N_cells / N_spots``. Kept for interface
        compatibility.
    save, output_path
        Currently unused by this method.
    lam_sim, lam_num, lam_M
        Weights of the cosine-similarity loss, the cell-number regulariser
        and the entropy-style regulariser on ``M_hat``.
    lr
        Learning rate of the Adamax optimiser.
    training_steps_M
        Number of optimisation steps.
    report_loss, step_interval
        If ``report_loss`` is true, print the losses every
        ``step_interval`` steps.

    Returns
    -------
    The filtered copy of the reference data, restricted to the genes of
    ``self.adata_st``, with ``obsm['spatial_aligned']``, ``obsm['3D_coor']``
    and ``obs['slice']`` copied from the spot assigned to each cell.

    Raises
    ------
    ValueError
        If ``celltype_ref`` is given but is not a list.

    Side effects: sets ``self.optimizer_M`` and ``self.M_hat`` (the final
    softmax matrix as a NumPy array).
    """
    import scanpy as sc

    # When map cells to spatial locations,
    # the reference dataset needs to be processed in the same way as we used it
    # to construct the cell-type matrix
    adata_ref = adata_ref_input.copy()
    adata_ref.var_names_make_unique()

    # Remove unnamed genes and mt-genes
    keep_genes = (np.array(~adata_ref.var.index.isna())
                  & np.array(~adata_ref.var_names.str.startswith("mt-"))
                  & np.array(~adata_ref.var_names.str.startswith("MT-")))
    adata_ref = adata_ref[:, keep_genes]

    if celltype_ref is not None:
        if not isinstance(celltype_ref, list):
            raise ValueError("'celltype_ref' must be a list!")
    else:
        # Default: keep every cell type represented by more than one cell.
        celltype_counts = adata_ref.obs[celltype_ref_col].value_counts()
        celltype_ref = list(celltype_counts.index[celltype_counts > 1])

    # Labels are compared as strings on both sides so that non-string labels
    # (e.g. integer cluster ids) are matched too; a set gives O(1) membership.
    allowed_celltypes = {str(t) for t in celltype_ref}
    all_labels = adata_ref.obs[celltype_ref_col].values.astype(str)
    adata_ref = adata_ref[[(t in allowed_celltypes) for t in all_labels], :]

    # Remove cells and genes with 0 counts
    sc.pp.filter_cells(adata_ref, min_genes=1)
    sc.pp.filter_genes(adata_ref, min_cells=1)

    # Restrict the reference to the genes of the spatial data, in their order.
    adata_ref = adata_ref[:, self.adata_st.var.index]

    n_cells = adata_ref.shape[0]
    n_spots = self.Y.shape[0]

    cell_labels = adata_ref.obs[celltype_ref_col].values.astype(str)
    celltype_list = sorted(set(cell_labels))

    if scipy.sparse.issparse(adata_ref.X):
        ref_counts = adata_ref.X.toarray()
    else:
        ref_counts = adata_ref.X

    # Generate count matrix for single cells
    ref_counts = torch.from_numpy(ref_counts).to(torch.float32).to(self.device)  # N_cells x G

    # Generate one-hot cell-type matrix for single cells.
    # The lookup uses the same string-cast labels that built celltype_list.
    celltype_index = {ct: j for j, ct in enumerate(celltype_list)}
    label_codes = np.array([celltype_index[ct] for ct in cell_labels], dtype=np.intp)
    celltype_onehot = np.zeros((n_cells, len(celltype_list)))
    celltype_onehot[np.arange(n_cells), label_codes] = 1.
    celltype_onehot = torch.from_numpy(celltype_onehot).to(torch.float32).to(self.device)  # N_cells x C

    # Generate adjusted expression matrix for spatial spots
    Y_adjusted = (torch.matmul(self.beta, self.basis) * self.lY).detach()  # N_spots x G

    beta = self.beta.detach()  # N_spots x C

    # Unnormalised cell-to-spot scores; the row-wise softmax of M is the mapping.
    M = torch.zeros(n_cells, n_spots, device=self.device, requires_grad=True)  # N_cells x N_spots

    self.optimizer_M = optim.Adamax([M], lr=lr)

    # Every row of M_hat sums to 1, so the column sums always total N_cells;
    # the average N_cells / N_spots is therefore used as the per-spot target
    # (the `target_num` argument is not used).
    cells_per_spot = n_cells / n_spots

    for step in tqdm(range(training_steps_M)):
        M_hat = F.softmax(M, dim=1)  # N_cells x N_spots
        # log(M_hat) computed stably: stays finite even where M_hat underflows
        # to 0, so the entropy term below cannot become NaN.
        log_M_hat = F.log_softmax(M, dim=1)
        M_hat_t = torch.transpose(M_hat, 0, 1)  # N_spots x N_cells
        spot_cell_num = torch.sum(M_hat, dim=0)  # N_spots

        generated_spots = torch.matmul(M_hat_t, ref_counts)  # N_spots x G
        loss_sim_spots = - torch.mean(F.cosine_similarity(Y_adjusted, generated_spots, dim=1))
        loss_sim_genes = - torch.mean(F.cosine_similarity(Y_adjusted, generated_spots, dim=0))

        generated_spots_prop = torch.matmul(M_hat_t, celltype_onehot)  # N_spots x C
        generated_spots_prop = generated_spots_prop / spot_cell_num.view(-1, 1)  # Normalize generated proportions

        loss_prop = torch.mean(torch.sum((generated_spots_prop - beta) ** 2, dim=1))

        # regularizers
        reg_cell_num = torch.mean((spot_cell_num - cells_per_spot) ** 2)
        reg_M = -torch.mean(M_hat * log_M_hat)

        loss_M = (loss_prop
                  + lam_sim * (loss_sim_spots + loss_sim_genes)
                  + lam_num * reg_cell_num
                  + lam_M * reg_M)

        self.optimizer_M.zero_grad()
        loss_M.backward()
        self.optimizer_M.step()

        if report_loss and not step % step_interval:
            print("Step: %s, Loss: %.4f, proption_loss: %.4f, spot_sim_loss: %.4f, cell_num_reg: %.4f, M_reg: %.4f" %
                  (step, loss_M.item(), loss_prop.item(), (loss_sim_spots + loss_sim_genes).item(),
                   reg_cell_num.item(), reg_M.item()))

    M_hat = F.softmax(M, dim=1)
    self.M_hat = M_hat.detach().cpu().numpy()

    # Assign each cell to its highest-scoring spot and copy that spot's
    # coordinates and slice label.
    assigned_spot = np.argmax(self.M_hat, axis=1)
    adata_st_assigned = self.adata_st[assigned_spot]
    adata_ref.obsm['spatial_aligned'] = adata_st_assigned.obsm['spatial_aligned']
    adata_ref.obsm['3D_coor'] = adata_st_assigned.obsm['3D_coor']
    adata_ref.obs['slice'] = adata_st_assigned.obs['slice'].values

    return adata_ref