
Commit

Update NB model option
jiazhao97 authored Jun 22, 2023
1 parent 803c056 commit 9f59b98
Showing 1 changed file with 99 additions and 0 deletions.
STitch3D/networks.py (99 additions, 0 deletions)
@@ -200,3 +200,102 @@ def deconvolutioner(self, Z, slice_label_emb):
        H = F.elu(torch.cat((Z, slice_label_emb), axis=1))
        alpha = self.deconv_alpha_layer(H)
        return beta, alpha


class DeconvNet_NB(nn.Module):

    def __init__(self,
                 hidden_dims, # dimensionality of hidden layers
                 n_celltypes, # number of cell types
                 n_slices, # number of slices
                 n_heads, # number of attention heads
                 slice_emb_dim, # dimensionality of slice id embedding
                 coef_fe, # weight of the feature reconstruction loss in the total loss
                 ):

        super().__init__()

        # define layers
        # encoder layers
        self.encoder_layer1 = GATMultiHead(hidden_dims[0], hidden_dims[1], n_heads=n_heads, concat_heads=True)
        self.encoder_layer2 = DenseLayer(hidden_dims[1], hidden_dims[2])
        # decoder layers
        self.decoder_layer1 = GATMultiHead(hidden_dims[2] + slice_emb_dim, hidden_dims[1], n_heads=n_heads, concat_heads=True)
        self.decoder_layer2 = DenseLayer(hidden_dims[1], hidden_dims[0])
        # deconvolution layers
        self.deconv_alpha_layer = DenseLayer(hidden_dims[2] + slice_emb_dim, 1, zero_init=True)
        self.deconv_beta_layer = DenseLayer(hidden_dims[2], n_celltypes, zero_init=True)

        # slice- and gene-specific offset on the log scale, initialized at zero
        self.gamma = nn.Parameter(torch.Tensor(n_slices, hidden_dims[0]).zero_())
        # log of the slice- and gene-specific negative binomial dispersion
        self.logtheta = nn.Parameter(5. * torch.ones(n_slices, hidden_dims[0]))

        self.slice_emb = nn.Embedding(n_slices, slice_emb_dim)

        self.coef_fe = coef_fe

    def forward(self,
                adj_matrix, # adjacency matrix including self-connections
                node_feats, # input node features
                count_matrix, # gene expression counts
                library_size, # library size (based on Y)
                slice_label, # slice label
                basis, # basis matrix
                ):
        # encoder
        Z = self.encoder(adj_matrix, node_feats)

        # deconvolutioner
        slice_label_emb = self.slice_emb(slice_label)
        beta, alpha = self.deconvolutioner(Z, slice_label_emb)

        # decoder
        node_feats_recon = self.decoder(adj_matrix, Z, slice_label_emb)

        # reconstruction loss of node features:
        # mean over spots of the Euclidean norm of the reconstruction error
        self.features_loss = torch.mean(torch.sqrt(torch.sum(torch.pow(node_feats - node_feats_recon, 2), axis=1)))

        # deconvolution loss: negative log-likelihood of the counts under a
        # negative binomial model with mean library_size * lam and slice-specific
        # dispersion theta; the parameter-free lgamma(count_matrix + 1) term is omitted
        log_lam = torch.log(torch.matmul(beta, basis) + 1e-6) + alpha + self.gamma[slice_label]
        lam = torch.exp(log_lam)
        theta = torch.exp(self.logtheta)
        self.decon_loss = - torch.mean(torch.sum(torch.lgamma(count_matrix + theta[slice_label] + 1e-6) -
                                                 torch.lgamma(theta[slice_label] + 1e-6) +
                                                 theta[slice_label] * torch.log(theta[slice_label] + 1e-6) -
                                                 theta[slice_label] * torch.log(theta[slice_label] + library_size * lam + 1e-6) +
                                                 count_matrix * torch.log(library_size * lam + 1e-6) -
                                                 count_matrix * torch.log(theta[slice_label] + library_size * lam + 1e-6), axis=1))

        # Total loss
        loss = self.decon_loss + self.coef_fe * self.features_loss

        return loss

    def evaluate(self, adj_matrix, node_feats, slice_label):
        slice_label_emb = self.slice_emb(slice_label)
        # encoder
        Z = self.encoder(adj_matrix, node_feats)

        # deconvolutioner
        beta, alpha = self.deconvolutioner(Z, slice_label_emb)

        return Z, beta, alpha, self.gamma

    def encoder(self, adj_matrix, node_feats):
        H = node_feats
        H = F.elu(self.encoder_layer1(H, adj_matrix))
        Z = self.encoder_layer2(H)
        return Z

    def decoder(self, adj_matrix, Z, slice_label_emb):
        H = torch.cat((Z, slice_label_emb), axis=1)
        H = F.elu(self.decoder_layer1(H, adj_matrix))
        X_recon = self.decoder_layer2(H)
        return X_recon

    def deconvolutioner(self, Z, slice_label_emb):
        beta = self.deconv_beta_layer(F.elu(Z))
        beta = F.softmax(beta, dim=1)
        H = F.elu(torch.cat((Z, slice_label_emb), axis=1))
        alpha = self.deconv_alpha_layer(H)
        return beta, alpha
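For reference, decon_loss above is the negated, spot-averaged log-likelihood of the counts under a negative binomial model in its mean-dispersion parameterization. Writing y for an entry of count_matrix, \theta for the slice-specific dispersion \exp(\texttt{logtheta}), and \mu = \ell \lambda for the mean (with \ell the library size and \log\lambda = \log(\beta B) + \alpha + \gamma_s), the per-gene log-likelihood is

\log p(y \mid \theta, \mu) = \log\Gamma(y+\theta) - \log\Gamma(\theta) - \log\Gamma(y+1) + \theta \log\frac{\theta}{\theta+\mu} + y \log\frac{\mu}{\theta+\mu}

The \log\Gamma(y+1) term does not depend on any model parameter, so the code drops it; the 1e-6 offsets only guard the logarithms against zero.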

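As a quick smoke test, here is a minimal sketch of how the new class might be exercised. All shapes and hyperparameter values below are hypothetical, and the dense identity adjacency is only a stand-in for the repository's actual spatial graph (the exact tensor format expected by GATMultiHead follows its definition earlier in networks.py):

import torch
from STitch3D.networks import DeconvNet_NB

n_spots, n_genes, n_celltypes, n_slices = 100, 2000, 8, 3

model = DeconvNet_NB(hidden_dims=[n_genes, 512, 128],  # input, hidden, latent widths (hypothetical)
                     n_celltypes=n_celltypes,
                     n_slices=n_slices,
                     n_heads=1,
                     slice_emb_dim=16,
                     coef_fe=0.1)

adj_matrix = torch.eye(n_spots)                                  # self-connections only (stand-in graph)
node_feats = torch.randn(n_spots, n_genes)                       # input node features
count_matrix = torch.randint(0, 20, (n_spots, n_genes)).float()  # raw expression counts
library_size = count_matrix.sum(dim=1, keepdim=True)             # per-spot library size
slice_label = torch.randint(0, n_slices, (n_spots,))             # slice id of each spot
basis = torch.rand(n_celltypes, n_genes)                         # cell-type signature (basis) matrix

loss = model(adj_matrix, node_feats, count_matrix, library_size, slice_label, basis)
loss.backward()

Z, beta, alpha, gamma = model.evaluate(adj_matrix, node_feats, slice_label)
print(loss.item(), beta.shape)  # beta: (n_spots, n_celltypes) cell-type proportions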