Skip to content

Commit

Permalink
Update single-cell spatial charting function
Browse files Browse the repository at this point in the history
  • Loading branch information
gefeiwang authored Jun 22, 2023
1 parent d71229b commit 803c056
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions STitch3D/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,104 @@ def eval(self, adata_st_list_raw, save=False, output_path="./results"):

return adata_st_decon_list


def cells_to_spatial(self, adata_ref_input,
                     celltype_ref_col="celltype", # column of adata_ref_input.obs for cell type information
                     celltype_ref=None, # specify cell types to use for deconvolution
                     target_num=20, # target number of cells per spot
                     save=False,
                     lam_sim=0.1,
                     lam_num=1e-3,
                     lam_M=1,
                     lr=2e-3, training_steps_M=20000, report_loss=True, step_interval=2000, output_path="./results"):
    """Assign reference single cells to spatial spots by optimizing a soft mapping.

    A free matrix ``M`` (N_cells x N_spots) is optimized with Adamax; its
    row-wise softmax ``M_hat`` is the soft assignment of each cell to the
    spots. After training, each cell is placed on its arg-max spot.

    Parameters
    ----------
    adata_ref_input
        Reference single-cell AnnData object. It is copied, not modified.
    celltype_ref_col
        Column of ``adata_ref_input.obs`` holding the cell-type labels.
    celltype_ref
        List of cell types to keep. If ``None``, every cell type that occurs
        more than once is kept. Labels are compared as strings.
    target_num
        Accepted for interface compatibility but not used: the cell-number
        regularizer targets ``N_cells / N_spots`` (see the note in the body).
    save, output_path
        Accepted for interface compatibility; this method does not read them.
    lam_sim
        Weight of the cosine-similarity terms (per spot and per gene).
    lam_num
        Weight of the cells-per-spot regularizer.
    lam_M
        Weight of the entropy-style regularizer on ``M_hat``.
    lr
        Learning rate of the Adamax optimizer.
    training_steps_M
        Number of optimization steps.
    report_loss
        If true, print the loss terms every ``step_interval`` steps.
    step_interval
        Reporting period, in steps.

    Returns
    -------
    The filtered reference AnnData with ``obsm['spatial_aligned']``,
    ``obsm['3D_coor']`` and ``obs['slice']`` copied from each cell's
    arg-max spot in ``self.adata_st``.

    Raises
    ------
    ValueError
        If ``celltype_ref`` is given but is not a list.

    Side effects: sets ``self.optimizer_M`` and ``self.M_hat`` (numpy array,
    N_cells x N_spots).
    """
    import scanpy as sc

    # When mapping cells to spatial locations, the reference dataset needs to
    # be processed in the same way as when the cell-type matrix was constructed.
    adata_ref = adata_ref_input.copy()
    adata_ref.var_names_make_unique()

    # Remove genes with missing names and mt-genes
    keep_genes = (np.array(~adata_ref.var.index.isna())
                  & np.array(~adata_ref.var_names.str.startswith("mt-"))
                  & np.array(~adata_ref.var_names.str.startswith("MT-")))
    adata_ref = adata_ref[:, keep_genes]

    # Cell-type labels are compared as strings everywhere below, so the
    # selected types are string-cast as well (otherwise non-string labels
    # would never match their own string form).
    if celltype_ref is not None:
        if not isinstance(celltype_ref, list):
            raise ValueError("'celltype_ref' must be a list!")
        selected_types = {str(t) for t in celltype_ref}
    else:
        celltype_counts = adata_ref.obs[celltype_ref_col].value_counts()
        selected_types = {str(t) for t in celltype_counts.index[celltype_counts > 1]}
    cell_labels = adata_ref.obs[celltype_ref_col].values.astype(str)
    keep_cells = np.array([label in selected_types for label in cell_labels], dtype=bool)
    adata_ref = adata_ref[keep_cells, :]

    # Remove cells and genes with 0 counts
    sc.pp.filter_cells(adata_ref, min_genes=1)
    sc.pp.filter_genes(adata_ref, min_cells=1)

    # Restrict and order the genes like the spatial data
    adata_ref = adata_ref[:, self.adata_st.var.index]

    # Labels of the cells that survived filtering, as strings
    cell_labels = adata_ref.obs[celltype_ref_col].values.astype(str)
    celltype_list = sorted(set(cell_labels))

    if scipy.sparse.issparse(adata_ref.X):
        ref_counts = adata_ref.X.toarray()
    else:
        ref_counts = np.asarray(adata_ref.X)

    # Generate count matrix for single cells
    ref_counts = torch.from_numpy(ref_counts).to(torch.float32).to(self.device) # N_cells x G

    n_cells = adata_ref.shape[0]
    n_spots = self.Y.shape[0]

    # Generate one-hot cell-type matrix for single cells.
    # A dict lookup replaces the per-cell list rebuild + list.index scan.
    celltype_position = {name: j for j, name in enumerate(celltype_list)}
    celltype_ids = np.array([celltype_position[label] for label in cell_labels], dtype=np.int64)
    celltype_onehot = np.zeros((n_cells, len(celltype_list)))
    celltype_onehot[np.arange(n_cells), celltype_ids] = 1.
    celltype_onehot = torch.from_numpy(celltype_onehot).to(torch.float32).to(self.device) # N_cells x C

    # Generate adjusted expression matrix for spatial spots
    Y_adjusted = (torch.matmul(self.beta, self.basis) * self.lY).detach() # N_spots x G

    beta = self.beta.detach() # N_spots x C

    M = torch.zeros(n_cells, n_spots, device=self.device, requires_grad=True) # N_cells x N_spots

    self.optimizer_M = optim.Adamax([M], lr=lr)

    # NOTE: every row of M_hat sums to one, so the column sums always add up
    # to N_cells; the regularizer therefore targets the balanced load
    # N_cells / N_spots. The `target_num` argument is not used.
    cells_per_spot = n_cells / n_spots

    for step in tqdm(range(training_steps_M)):
        # log_softmax keeps log(M_hat) finite even where M_hat underflows to
        # zero, so the entropy term below cannot turn into 0 * (-inf) = NaN.
        log_M_hat = F.log_softmax(M, dim=1)
        M_hat = log_M_hat.exp() # N_cells x N_spots
        M_hat_t = torch.transpose(M_hat, 0, 1) # N_spots x N_cells

        generated_spots = torch.matmul(M_hat_t, ref_counts) # N_spots x G
        loss_sim_spots = - torch.mean(F.cosine_similarity(Y_adjusted, generated_spots, dim=1))
        loss_sim_genes = - torch.mean(F.cosine_similarity(Y_adjusted, generated_spots, dim=0))

        spot_load = torch.sum(M_hat, dim=0) # soft number of cells per spot
        generated_spots_prop = torch.matmul(M_hat_t, celltype_onehot) # N_spots x C
        generated_spots_prop = generated_spots_prop / spot_load.view(-1, 1) # Normalize generated proportions

        loss_prop = torch.mean(torch.sum((generated_spots_prop - beta) ** 2, dim=1))

        # regularizers
        reg_cell_num = torch.mean((spot_load - cells_per_spot) ** 2)
        reg_M = -torch.mean(M_hat * log_M_hat)
        loss_M = loss_prop + lam_sim * (loss_sim_spots + loss_sim_genes) + lam_num * reg_cell_num + lam_M * reg_M

        self.optimizer_M.zero_grad()
        loss_M.backward()
        self.optimizer_M.step()

        if report_loss and not step % step_interval:
            print("Step: %s, Loss: %.4f, proption_loss: %.4f, spot_sim_loss: %.4f, cell_num_reg: %.4f, M_reg: %.4f" %
                  (step, loss_M.item(), loss_prop.item(), (loss_sim_spots + loss_sim_genes).item(), reg_cell_num.item(), reg_M.item()))

    M_hat = F.softmax(M, dim=1)
    self.M_hat = M_hat.detach().cpu().numpy()

    # Place every cell on its most probable spot; subset the spatial data once
    best_spot = np.argmax(self.M_hat, axis=1)
    matched_spots = self.adata_st[best_spot]
    adata_ref.obsm['spatial_aligned'] = matched_spots.obsm['spatial_aligned']
    adata_ref.obsm['3D_coor'] = matched_spots.obsm['3D_coor']
    adata_ref.obs['slice'] = matched_spots.obs['slice'].values

    return adata_ref


0 comments on commit 803c056

Please sign in to comment.