diff --git a/STitch3D/model.py b/STitch3D/model.py index 5b60e8f..6187e06 100644 --- a/STitch3D/model.py +++ b/STitch3D/model.py @@ -20,6 +20,7 @@ def __init__(self, adata_st, adata_basis, training_steps=20000, lr=2e-3, seed=1234, + distribution="Poisson" ): self.training_steps = training_steps @@ -42,13 +43,22 @@ def __init__(self, adata_st, adata_basis, self.n_slices = len(sorted(set(adata_st.obs["slice"].values))) # build model - self.net = DeconvNet(hidden_dims=self.hidden_dims, - n_celltypes=self.n_celltype, - n_slices=self.n_slices, - n_heads=n_heads, - slice_emb_dim=slice_emb_dim, - coef_fe=coef_fe, - ).to(self.device) + if distribution == "Poisson": + self.net = DeconvNet(hidden_dims=self.hidden_dims, + n_celltypes=self.n_celltype, + n_slices=self.n_slices, + n_heads=n_heads, + slice_emb_dim=slice_emb_dim, + coef_fe=coef_fe, + ).to(self.device) + else: #Negative Binomial distribution + self.net = DeconvNet_NB(hidden_dims=self.hidden_dims, + n_celltypes=self.n_celltype, + n_slices=self.n_slices, + n_heads=n_heads, + slice_emb_dim=slice_emb_dim, + coef_fe=coef_fe, + ).to(self.device) self.optimizer = optim.Adamax(list(self.net.parameters()), lr=lr)