-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add StarDist+RCTD and 6 benchmarking datasets
- Loading branch information
1 parent
9c0c8a9
commit 96b5b1a
Showing
9 changed files
with
17,108 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
import anndata | ||
import numpy as np | ||
import pandas as pd | ||
import sys | ||
import pickle | ||
import os | ||
import copy | ||
import argparse | ||
from sklearn.model_selection import KFold | ||
from sklearn.metrics import mean_squared_error | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import scanpy as sc | ||
import warnings | ||
warnings.filterwarnings('ignore') | ||
import seaborn as sns | ||
|
||
from SDRCTD_utils import * | ||
|
||
|
||
|
||
|
||
class SDRCTD: | ||
def __init__(self,tissue,out_dir,RCTD_results_dir,RCTD_results_name, ST_Data, SC_Data, cell_class_column = 'cell_type', hs_ST = True): | ||
self.tissue = tissue | ||
self.out_dir = out_dir | ||
self.RCTD_results_dir = RCTD_results_dir | ||
self.RCTD_results_name = RCTD_results_name | ||
self.ST_Data = ST_Data | ||
self.SC_Data = SC_Data | ||
self.cell_class_column = cell_class_column | ||
self.hs_ST = hs_ST | ||
|
||
if not os.path.exists(out_dir): | ||
os.mkdir(out_dir) | ||
if not os.path.exists(os.path.join(out_dir,tissue)): | ||
os.mkdir(os.path.join(out_dir,tissue)) | ||
|
||
self.out_dir = os.path.join(out_dir,tissue) | ||
loggings = configure_logging(os.path.join(self.out_dir,'logs')) | ||
self.loggings = loggings | ||
|
||
self.LoadRCTDresults() | ||
if SC_Data is not None: | ||
self.LoadSCData() | ||
|
||
|
||
def LoadRCTDresults(self): | ||
with open(os.path.join(self.RCTD_results_dir, self.RCTD_results_name + '.pickle'), 'rb') as handle: | ||
RCTD_results = pickle.load(handle) | ||
|
||
if self.hs_ST: | ||
try: | ||
weights = RCTD_results['results']['weights'] | ||
except: | ||
weights = RCTD_results['results'] | ||
else: | ||
weights = RCTD_results['results'] | ||
|
||
|
||
self.weights = (weights / np.array(weights.sum(1))[:, None]) | ||
|
||
def single_cell_type_assignment(self, cell_num_column = 'cell_count', VisiumCellsPlot = True): | ||
seged_sp_adata = sc.read(self.ST_Data) #ST_Data already complete nuclei segmentation with StarDist. 'cell_locations' already in uns | ||
|
||
mat = self.weights.values | ||
cell_nums = np.array(seged_sp_adata.obs[cell_num_column]) | ||
|
||
cell_counts = distribute_cells(mat, cell_nums) | ||
cell_types = self.weights.columns | ||
cell_type_list = assign_cell_type(cell_counts, cell_types) | ||
seged_sp_adata.uns['cell_locations']['SDRCTD_cell_type'] = cell_type_list | ||
|
||
self.cell_type_list = cell_type_list | ||
self.seged_sp_adata = seged_sp_adata | ||
|
||
seged_sp_adata.uns['RCTD_weights'] = self.weights | ||
seged_sp_adata.write(os.path.join(self.out_dir, 'single_cell_type_label_bySDRCTD.h5ad')) | ||
|
||
# plot results | ||
if self.hs_ST or not VisiumCellsPlot: | ||
fig, ax = plt.subplots(figsize=(10,8.5),dpi=100) | ||
sns.scatterplot(data=seged_sp_adata.uns['cell_locations'], x="x",y="y",s=10,hue='SDRCTD_cell_type',palette='tab20',legend=True) | ||
plt.axis('off') | ||
plt.legend(bbox_to_anchor=(0.97, .98),framealpha=0) | ||
plt.savefig(os.path.join(self.out_dir, 'SDRCTD_estemated_ct_label.png')) | ||
plt.close() | ||
|
||
elif VisiumCellsPlot: | ||
if seged_sp_adata.obsm['spatial'].shape[1] == 2: | ||
fig, ax = plt.subplots(1,1,figsize=(14, 8),dpi=200) | ||
PlotVisiumCells(seged_sp_adata,"SDRCTD_cell_type",size=0.4,alpha_img=0.4,lw=0.4,palette='tab20',ax=ax) | ||
plt.savefig(os.path.join(self.out_dir, 'SDRCTD_estemated_ct_label.png')) | ||
plt.close() | ||
|
||
def cell_type_mean_assignment(self): | ||
# cell type mean as decomposed cell gene expression | ||
ref_df = pd.DataFrame([[ct, i]for i, ct in enumerate(self.sc_data_process_marker.obs[self.cell_class_column].astype('category').cat.categories)], columns = ['cell_type', 'cell_type_code']) | ||
ref_df.index = ref_df.cell_type | ||
ref_df = ref_df.iloc[:,1:] | ||
|
||
x_decom = self.mu[ref_df.loc[np.array(self.cell_type_list)].cell_type_code.tolist()] | ||
x_decom_adata = anndata.AnnData(X = x_decom.copy(), obs = self.seged_sp_adata.uns['cell_locations'].copy(), var = self.sc_data_process_marker.var) | ||
x_decom_adata.write(os.path.join(self.out_dir, 'cell_type_mean_bySDRCTD.h5ad')) | ||
|
||
|
||
def LoadSCData(self): | ||
# load sc data | ||
sc_data_process = anndata.read_h5ad(self.SC_Data) | ||
if 'Marker' in sc_data_process.var.columns: | ||
sc_data_process_marker = sc_data_process[:,sc_data_process.var['Marker']] | ||
else: | ||
sc_data_process_marker = sc_data_process | ||
|
||
if sc_data_process_marker.X.max() <= 30: | ||
self.loggings.info(f'Maximum value: {sc_data_process_marker.X.max()}, need to run exp') | ||
try: | ||
sc_data_process_marker.X = np.exp(sc_data_process_marker.X) - 1 | ||
except: | ||
sc_data_process_marker.X = np.exp(sc_data_process_marker.X.toarray()) - 1 | ||
|
||
|
||
cell_type_array = np.array(sc_data_process_marker.obs[self.cell_class_column]) | ||
cell_type_class = np.unique(cell_type_array) | ||
df_category = sc_data_process_marker.obs[[self.cell_class_column]].astype('category').apply(lambda x: x.cat.codes) | ||
|
||
# parameters: mean and cell type index | ||
cell_type_array_code = np.array(df_category[self.cell_class_column]) | ||
try: | ||
data = sc_data_process_marker.X.toarray() | ||
except: | ||
data = sc_data_process_marker.X | ||
|
||
n, d = data.shape | ||
q = cell_type_class.shape[0] | ||
self.loggings.info(f'scRNA-seq data shape: {data.shape}') | ||
self.loggings.info(f'scRNA-seq cell class number: {q}') | ||
|
||
mu = np.zeros((q, d)) | ||
for k in range(q): | ||
mu[k] = data[cell_type_array_code == k].mean(0).squeeze() | ||
self.mu = mu | ||
self.sc_data_process_marker = sc_data_process_marker | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
HEADER = """ | ||
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><> | ||
<> | ||
<> StarDist + RCTD | ||
<> Version: %s | ||
<> MIT License | ||
<> | ||
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><> | ||
<> Software-related correspondence: %s or %s | ||
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><> | ||
<> Visium data example: | ||
python <install path>/src/Cell_Type_Identification.py \\ | ||
--cell_class_column cell_type \\ | ||
--tissue heart \\ | ||
--out_dir ./output \\ | ||
--ST_Data ./output/heart/sp_adata_ns.h5ad \\ | ||
--SC_Data ./Ckpts_scRefs/Heart_D2/Ref_Heart_sanger_D2.h5ad | ||
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><> | ||
""" | ||
def str2bool(v): | ||
if isinstance(v, bool): | ||
return v | ||
if v.lower() in ("yes", "true", "t", "y", "1"): | ||
return True | ||
elif v.lower() in ("no", "false", "f", "n", "0"): | ||
return False | ||
else: | ||
raise argparse.ArgumentTypeError("Boolean value expected.") | ||
|
||
parser = argparse.ArgumentParser(description='simulation sour_sep') | ||
parser.add_argument('--out_dir', type=str, help='output path', default=None) | ||
parser.add_argument('--RCTD_results_dir', type=str, help='RCTD results path', default=None) | ||
parser.add_argument('--RCTD_results_name', type=str, help='RCTD results file\'s name', default='InitProp') | ||
parser.add_argument('--ST_Data', type=str, help='ST data path', default=None) | ||
parser.add_argument('--SC_Data', type=str, help='single cell reference data path', default=None) | ||
parser.add_argument('--cell_class_column', type=str, help='input cell class label column in scRef file', default = 'cell_type') | ||
parser.add_argument('--cell_num_column', type=str, help='cell number column in spatial file', default = 'cell_count') | ||
parser.add_argument('--hs_ST', action="store_true", help='high resolution ST data such as Slideseq, DBiT-seq, and HDST, MERFISH etc.') | ||
parser.add_argument("--VisiumCellsPlot", type=str2bool, const=True, default=True, nargs="?", help="whether to plot in VisiumCells mode or just scatter plot") | ||
args = parser.parse_args() | ||
|
||
args.tissue = 'SDRCTD_results' | ||
if not os.path.exists(args.out_dir): | ||
os.mkdir(args.out_dir) | ||
if not os.path.exists(os.path.join(args.out_dir,args.tissue)): | ||
os.mkdir(os.path.join(args.out_dir,args.tissue)) | ||
if args.RCTD_results_dir is None: | ||
args.RCTD_results_dir = args.out_dir | ||
|
||
|
||
sdr = SDRCTD(args.tissue,args.out_dir, args.RCTD_results_dir, args.RCTD_results_name, args.ST_Data, args.SC_Data, cell_class_column = args.cell_class_column, hs_ST = args.hs_ST) | ||
sdr.single_cell_type_assignment(cell_num_column = args.cell_num_column, VisiumCellsPlot = args.VisiumCellsPlot) | ||
if args.SC_Data is not None: | ||
sdr.cell_type_mean_assignment() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import os | ||
import scanpy as sc | ||
# import squidpy as sq | ||
import numpy as np | ||
import pandas as pd | ||
import pathlib | ||
import matplotlib.pyplot as plt | ||
import matplotlib as mpl | ||
# import skimage | ||
import seaborn as sns | ||
from itertools import chain | ||
# from stardist.models import StarDist2D | ||
from csbdeep.utils import normalize | ||
from anndata import AnnData | ||
from scipy.spatial.distance import pdist | ||
import logging | ||
import sys | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
|
||
def PlotVisiumCells(adata,annotation_list,size=0.8,alpha_img=0.3,lw=1,subset=None,palette='tab20',show_circle = True, legend = True, ax=None,**kwargs): | ||
merged_df = adata.uns['cell_locations'].copy() | ||
test = sc.AnnData(np.zeros(merged_df.shape), obs=merged_df) | ||
test.obsm['spatial'] = merged_df[["x", "y"]].to_numpy().copy() | ||
test.uns = adata.uns | ||
|
||
if subset is not None: | ||
#test = test[test.obs[annotation_list].isin(subset)] | ||
test.obs.loc[~test.obs[annotation_list].isin(subset),annotation_list] = None | ||
|
||
sc.pl.spatial( | ||
test, | ||
color=annotation_list, | ||
size=size, | ||
frameon=False, | ||
alpha_img=alpha_img, | ||
show=False, | ||
palette=palette, | ||
na_in_legend=False, | ||
ax=ax,title='',sort_order=True,**kwargs | ||
) | ||
if show_circle: | ||
sf = adata.uns['spatial'][list(adata.uns['spatial'].keys())[0]]['scalefactors']['tissue_hires_scalef'] | ||
spot_radius = adata.uns['spatial'][list(adata.uns['spatial'].keys())[0]]['scalefactors']['spot_diameter_fullres']/2 | ||
for sloc in adata.obsm['spatial']: | ||
rect = mpl.patches.Circle( | ||
(sloc[0] * sf, sloc[1] * sf), | ||
spot_radius * sf, | ||
ec="grey", | ||
lw=lw, | ||
fill=False | ||
) | ||
ax.add_patch(rect) | ||
ax.axes.xaxis.label.set_visible(False) | ||
ax.axes.yaxis.label.set_visible(False) | ||
|
||
if not legend: | ||
ax.get_legend().remove() | ||
|
||
# make frame visible | ||
for _, spine in ax.spines.items(): | ||
spine.set_visible(True) | ||
|
||
|
||
|
||
def assign_cell_type(cell_counts, cell_types): | ||
cell_type_list = [] | ||
for i in range(cell_counts.shape[0]): | ||
cell_count = cell_counts[i] | ||
idx = np.where(cell_count > 0)[0] | ||
cell_type_list_row = [[cell_types[idx][_]] * cell_count[idx[_]] for _ in range(idx.shape[0])] | ||
cell_type_list_row = np.array([item for sublist in cell_type_list_row for item in sublist]) | ||
np.random.shuffle(cell_type_list_row) | ||
cell_type_list = cell_type_list + list(cell_type_list_row) | ||
|
||
return cell_type_list | ||
|
||
def distribute_cells(mat, cell_nums): # mat: spots * cell_type; cell_nums: spots * 1 | ||
cell_nums = cell_nums.astype(int) | ||
cell_nums_original = cell_nums.copy() | ||
|
||
mat[np.absolute(mat) < 1e-3] = 0 | ||
cell_counts = np.zeros(mat.shape).astype(int) | ||
assert not np.any(cell_nums < 0) | ||
assert not np.any(mat < 0) | ||
|
||
mat = mat * cell_nums[:, None] | ||
cell_num_dist = np.floor(mat).astype(int) | ||
cell_counts = cell_counts + cell_num_dist | ||
cell_nums_remain = cell_nums - cell_num_dist.sum(1) | ||
mat_remain = mat - cell_num_dist | ||
|
||
assert not np.any(cell_nums_remain < 0) | ||
assert not np.any(mat_remain < 0) | ||
|
||
mat = mat_remain | ||
cell_nums = cell_nums_remain | ||
|
||
while(np.any(cell_nums_remain > 0)): | ||
mat[mat.argsort()[:, ::-1].argsort() >= cell_nums[:, None]] = 0 | ||
mat = np.divide(mat, mat.sum(1)[:,None], out=np.zeros_like(mat), where=mat.sum(1)[:,None]!=0) | ||
mat = mat * cell_nums[:, None] | ||
cell_num_dist = np.floor(mat).astype(int) | ||
cell_counts = cell_counts + cell_num_dist | ||
cell_nums_remain = cell_nums - cell_num_dist.sum(1) | ||
mat_remain = mat - cell_num_dist | ||
|
||
assert not np.any(cell_nums_remain < 0) | ||
assert not np.any(mat_remain < 0) | ||
|
||
mat = mat_remain | ||
cell_nums = cell_nums_remain | ||
|
||
assert np.array_equal(cell_counts.sum(1), cell_nums_original) | ||
|
||
return cell_counts | ||
|
||
|
||
|
||
|
||
def configure_logging(logger_name): | ||
LOG_LEVEL = logging.DEBUG | ||
log_filename = logger_name+'.log' | ||
importer_logger = logging.getLogger('importer_logger') | ||
importer_logger.setLevel(LOG_LEVEL) | ||
formatter = logging.Formatter('%(asctime)s : %(levelname)s : %(message)s') | ||
|
||
fh = logging.FileHandler(filename=log_filename) | ||
fh.setLevel(LOG_LEVEL) | ||
fh.setFormatter(formatter) | ||
importer_logger.addHandler(fh) | ||
|
||
sh = logging.StreamHandler(sys.stdout) | ||
sh.setLevel(LOG_LEVEL) | ||
sh.setFormatter(formatter) | ||
importer_logger.addHandler(sh) | ||
return importer_logger | ||
|
||
|
||
|
Oops, something went wrong.