From 3413105a83c266f550f04865e38a3c8a65ab05c8 Mon Sep 17 00:00:00 2001 From: Starlitnightly Date: Mon, 3 Jun 2024 16:28:31 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Drawing=203D=20plot=20keeping=20?= =?UTF-8?q?the=20original=20colours?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scSLAT/viz/multi_dataset.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/scSLAT/viz/multi_dataset.py b/scSLAT/viz/multi_dataset.py index f737f10..0b82b5e 100644 --- a/scSLAT/viz/multi_dataset.py +++ b/scSLAT/viz/multi_dataset.py @@ -90,6 +90,8 @@ def __init__(self,adatas:List[AnnData], self.loc_list.append(loc) self.anno_list.append(anno) + self.adatas = adatas + self.anno_key=anno_key self.celltypes = set(pd.concat(self.anno_list)) self.subsample_size = subsample_size @@ -130,6 +132,9 @@ def draw_3D(self, ax = fig.add_subplot(111, projection='3d') ax.set_box_aspect([1, 1, height_scale * len(self.mappings)]) # color by different cell types + + + color = get_color(len(self.celltypes)) c_map = {} for i, celltype in enumerate(self.celltypes): @@ -137,10 +142,21 @@ def draw_3D(self, for j, mapping in enumerate(self.mappings): print(f"Mapping {j}th layer ") # plot cells - for i, (layer, anno) in enumerate(zip(self.loc_list[j:j+2], self.anno_list[j:j+2])): + for i, (layer, anno,ad) in enumerate(zip(self.loc_list[j:j+2], self.anno_list[j:j+2],self.adatas[j:j+2])): if i==0 and 028: + c_map=dict(zip(ad.obs[self.anno_key].cat.categories,sc.pl.palettes.default_102)) + else: + c_map=dict(zip(ad.obs[self.anno_key].cat.categories,sc.pl.palettes.zeileis_28)) + + for cell_type in ad.obs[self.anno_key].cat.categories: slice = layer[anno == cell_type,:] xs = slice[:,0] ys = slice[:,1] @@ -159,7 +175,8 @@ def draw_3D(self, if hide_axis: plt.axis('off') - plt.show() + return ax + #plt.show() class match_3D_multi():