Skip to content

Commit

Permalink
Update NB model option
Browse files Browse the repository at this point in the history
  • Loading branch information
jiazhao97 authored Jun 22, 2023
1 parent 9f59b98 commit ad989b5
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions STitch3D/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, adata_st, adata_basis,
training_steps=20000,
lr=2e-3,
seed=1234,
distribution="Poisson"
):

self.training_steps = training_steps
Expand All @@ -42,13 +43,22 @@ def __init__(self, adata_st, adata_basis,
self.n_slices = len(sorted(set(adata_st.obs["slice"].values)))

# build model
self.net = DeconvNet(hidden_dims=self.hidden_dims,
n_celltypes=self.n_celltype,
n_slices=self.n_slices,
n_heads=n_heads,
slice_emb_dim=slice_emb_dim,
coef_fe=coef_fe,
).to(self.device)
if distribution == "Poisson":
self.net = DeconvNet(hidden_dims=self.hidden_dims,
n_celltypes=self.n_celltype,
n_slices=self.n_slices,
n_heads=n_heads,
slice_emb_dim=slice_emb_dim,
coef_fe=coef_fe,
).to(self.device)
else: #Negative Binomial distribution
self.net = DeconvNet_NB(hidden_dims=self.hidden_dims,
n_celltypes=self.n_celltype,
n_slices=self.n_slices,
n_heads=n_heads,
slice_emb_dim=slice_emb_dim,
coef_fe=coef_fe,
).to(self.device)

self.optimizer = optim.Adamax(list(self.net.parameters()), lr=lr)

Expand Down

0 comments on commit ad989b5

Please sign in to comment.