Skip to content

Commit

Permalink
add StarDist+RCTD and 6 benchmarking datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaShun-Xiao committed Aug 28, 2023
1 parent 9c0c8a9 commit 96b5b1a
Show file tree
Hide file tree
Showing 9 changed files with 17,108 additions and 12 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ We provide source codes for reproducing the SpatialScope analysis in the main te

All relevant materials used by the reproduction code are available [here](https://drive.google.com/drive/folders/1PXv_brtr-tXshBVEd_HSPIagjX9oF7Kg?usp=sharing)

+ [Benchmarking](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-MERFISH.ipynb)
+ [Benchmarking Dataset 1](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_1.ipynb)
+ [Benchmarking Dataset 2](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_2.ipynb)
+ [Benchmarking Dataset 3](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_3.ipynb)
+ [Benchmarking Dataset 4](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_4.ipynb)
+ [Benchmarking Dataset 5](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_5.ipynb)
+ [Benchmarking Dataset 6](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_6.ipynb)
+ [Human Heart (Visium, a single slice)](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Human-Heart.ipynb)
+ [Mouse Brain (Visium, 3D alignment of multiple slices)](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Mouse-Brain.ipynb)
+ [Mouse Cerebellum (Slideseq-V2)](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Mouse-Cerebellum-Slideseq.ipynb)
Expand Down
205 changes: 205 additions & 0 deletions compared_methods/SDRCTD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import anndata
import numpy as np
import pandas as pd
import sys
import pickle
import os
import copy
import argparse
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import warnings
warnings.filterwarnings('ignore')
import seaborn as sns

from SDRCTD_utils import *




class SDRCTD:
    """StarDist + RCTD baseline: label every segmented cell with one cell type.

    The constructor prepares ``<out_dir>/<tissue>``, sets up logging, loads the
    pickled RCTD weights and, when ``SC_Data`` is given, the single-cell
    reference.  ``single_cell_type_assignment`` then turns the per-spot weights
    into per-cell labels, and ``cell_type_mean_assignment`` writes the
    reference mean expression of the assigned type for every cell.
    """

    def __init__(self, tissue, out_dir, RCTD_results_dir, RCTD_results_name, ST_Data, SC_Data, cell_class_column='cell_type', hs_ST=True):
        """Store the configuration, create the output folder and load inputs.

        Args:
            tissue: Name of the sub-folder created inside ``out_dir``.
            out_dir: Root output directory.
            RCTD_results_dir: Directory holding the pickled RCTD results.
            RCTD_results_name: File name of the pickle, without ``.pickle``.
            ST_Data: Path of the spatial data file read by ``sc.read``.
            SC_Data: Path of the single-cell reference ``.h5ad`` or ``None``.
            cell_class_column: ``obs`` column of the reference with cell types.
            hs_ST: Flag forwarded from the ``--hs_ST`` command line switch.
        """
        self.tissue = tissue
        self.RCTD_results_dir = RCTD_results_dir
        self.RCTD_results_name = RCTD_results_name
        self.ST_Data = ST_Data
        self.SC_Data = SC_Data
        self.cell_class_column = cell_class_column
        self.hs_ST = hs_ST

        # makedirs also creates missing parents and tolerates existing folders,
        # unlike the chained exists()/mkdir() calls it replaces.
        self.out_dir = os.path.join(out_dir, tissue)
        os.makedirs(self.out_dir, exist_ok=True)
        self.loggings = configure_logging(os.path.join(self.out_dir, 'logs'))

        self.LoadRCTDresults()
        if SC_Data is not None:
            self.LoadSCData()

    def LoadRCTDresults(self):
        """Load the RCTD weights and normalise every row to sum to one.

        The result is stored in ``self.weights``.
        """
        results_path = os.path.join(self.RCTD_results_dir, self.RCTD_results_name + '.pickle')
        # NOTE(security): unpickling runs arbitrary code embedded in the file;
        # only point this at results you produced yourself.
        with open(results_path, 'rb') as handle:
            RCTD_results = pickle.load(handle)

        results = RCTD_results['results']
        weights = results
        if self.hs_ST:
            try:
                weights = results['weights']
            except (KeyError, IndexError, TypeError, ValueError):
                # 'results' has no nested 'weights' entry: it already is the
                # weight table itself.
                weights = results

        # NOTE(review): rows whose weights sum to zero become NaN here.
        self.weights = (weights / np.array(weights.sum(1))[:, None])

    def single_cell_type_assignment(self, cell_num_column='cell_count', VisiumCellsPlot=True):
        """Assign one cell type to each segmented cell and save the result.

        Writes ``single_cell_type_label_bySDRCTD.h5ad`` and, depending on the
        flags, ``SDRCTD_estemated_ct_label.png`` into ``self.out_dir``.

        Args:
            cell_num_column: ``obs`` column holding the number of cells per spot.
            VisiumCellsPlot: Draw with ``PlotVisiumCells`` (only used when
                ``hs_ST`` is false); otherwise a plain scatter plot is drawn.
        """
        # ST_Data already went through StarDist nuclei segmentation, so
        # 'cell_locations' is expected in .uns.
        seged_sp_adata = sc.read(self.ST_Data)

        # Pass a private copy: distribute_cells zeroes small entries of the
        # matrix it receives, which must not leak into self.weights (saved
        # below as 'RCTD_weights').
        mat = np.array(self.weights.values, dtype=float)
        cell_nums = np.array(seged_sp_adata.obs[cell_num_column])

        # NOTE(review): assumes the rows of self.weights are in the same order
        # as seged_sp_adata.obs, and that 'cell_locations' lists cells grouped
        # by spot in that order -- confirm against the upstream pipeline.
        cell_counts = distribute_cells(mat, cell_nums)
        cell_types = self.weights.columns
        cell_type_list = assign_cell_type(cell_counts, cell_types)
        seged_sp_adata.uns['cell_locations']['SDRCTD_cell_type'] = cell_type_list

        self.cell_type_list = cell_type_list
        self.seged_sp_adata = seged_sp_adata

        seged_sp_adata.uns['RCTD_weights'] = self.weights
        seged_sp_adata.write(os.path.join(self.out_dir, 'single_cell_type_label_bySDRCTD.h5ad'))

        # plot results
        if self.hs_ST or not VisiumCellsPlot:
            fig, ax = plt.subplots(figsize=(10, 8.5), dpi=100)
            sns.scatterplot(data=seged_sp_adata.uns['cell_locations'], x="x", y="y", s=10, hue='SDRCTD_cell_type', palette='tab20', legend=True)
            plt.axis('off')
            plt.legend(bbox_to_anchor=(0.97, .98), framealpha=0)
            plt.savefig(os.path.join(self.out_dir, 'SDRCTD_estemated_ct_label.png'))
            plt.close()

        elif VisiumCellsPlot:
            # Only 2-D coordinates are drawn; other shapes are skipped silently.
            if seged_sp_adata.obsm['spatial'].shape[1] == 2:
                fig, ax = plt.subplots(1, 1, figsize=(14, 8), dpi=200)
                PlotVisiumCells(seged_sp_adata, "SDRCTD_cell_type", size=0.4, alpha_img=0.4, lw=0.4, palette='tab20', ax=ax)
                plt.savefig(os.path.join(self.out_dir, 'SDRCTD_estemated_ct_label.png'))
                plt.close()

    def cell_type_mean_assignment(self):
        """Write the reference mean of the assigned cell type for every cell.

        Requires ``LoadSCData`` and ``single_cell_type_assignment`` to have run.
        Output: ``cell_type_mean_bySDRCTD.h5ad`` in ``self.out_dir``.
        """
        # cell type mean as decomposed cell gene expression
        code_of = {ct: code for code, ct in enumerate(self.cell_type_categories)}
        codes = [code_of[ct] for ct in self.cell_type_list]

        x_decom = self.mu[codes]
        x_decom_adata = anndata.AnnData(X=x_decom.copy(), obs=self.seged_sp_adata.uns['cell_locations'].copy(), var=self.sc_data_process_marker.var)
        x_decom_adata.write(os.path.join(self.out_dir, 'cell_type_mean_bySDRCTD.h5ad'))

    def LoadSCData(self):
        """Load the single-cell reference and compute per-cell-type means.

        Sets ``self.mu`` (cell types x genes), ``self.sc_data_process_marker``
        and ``self.cell_type_categories`` (row labels of ``self.mu``).
        """
        # load sc data
        sc_data_process = anndata.read_h5ad(self.SC_Data)
        if 'Marker' in sc_data_process.var.columns:
            # .copy() turns the view into a real object so X can be reassigned.
            sc_data_process_marker = sc_data_process[:, sc_data_process.var['Marker']].copy()
        else:
            sc_data_process_marker = sc_data_process

        max_value = sc_data_process_marker.X.max()
        if max_value <= 30:
            # Small maxima are taken to mean log1p-transformed data; undo it.
            self.loggings.info(f'Maximum value: {max_value}, need to run exp')
            X = sc_data_process_marker.X
            if hasattr(X, 'toarray'):
                # sparse matrices must be densified before np.exp
                X = X.toarray()
            sc_data_process_marker.X = np.exp(X) - 1

        # One categorical drives both the codes and the category list, and
        # unused categories are dropped so that code k always maps to row k of mu.
        labels = sc_data_process_marker.obs[self.cell_class_column].astype('category').cat.remove_unused_categories()
        cell_type_array_code = np.array(labels.cat.codes)
        cell_type_categories = labels.cat.categories

        # parameters: mean and cell type index
        data = sc_data_process_marker.X
        if hasattr(data, 'toarray'):
            data = data.toarray()

        n, d = data.shape
        q = len(cell_type_categories)
        self.loggings.info(f'scRNA-seq data shape: {data.shape}')
        self.loggings.info(f'scRNA-seq cell class number: {q}')

        mu = np.zeros((q, d))
        for k in range(q):
            mu[k] = data[cell_type_array_code == k].mean(0).squeeze()
        self.mu = mu
        self.sc_data_process_marker = sc_data_process_marker
        self.cell_type_categories = cell_type_categories







if __name__ == "__main__":
    # Command line entry point: parse the arguments, run the per-cell type
    # assignment and, when a single-cell reference is given, also write the
    # cell-type-mean expression.
    #
    # NOTE(review): HEADER is never printed or formatted below (its %s
    # placeholders stay unfilled) and its usage example names a different
    # script; it is kept unchanged.
    HEADER = """
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
<>
<> StarDist + RCTD
<> Version: %s
<> MIT License
<>
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
<> Software-related correspondence: %s or %s
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
<> Visium data example:
python <install path>/src/Cell_Type_Identification.py \\
--cell_class_column cell_type \\
--tissue heart \\
--out_dir ./output \\
--ST_Data ./output/heart/sp_adata_ns.h5ad \\
--SC_Data ./Ckpts_scRefs/Heart_D2/Ref_Heart_sanger_D2.h5ad
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
"""

    def str2bool(v):
        """Parse a yes/no style command line value into a bool."""
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(description='simulation sour_sep')
    # --out_dir and --ST_Data are required: with the former None default the
    # script crashed later with an unhelpful TypeError instead of a usage error.
    parser.add_argument('--out_dir', type=str, help='output path', default=None, required=True)
    parser.add_argument('--RCTD_results_dir', type=str, help='RCTD results path', default=None)
    parser.add_argument('--RCTD_results_name', type=str, help='RCTD results file\'s name', default='InitProp')
    parser.add_argument('--ST_Data', type=str, help='ST data path', default=None, required=True)
    parser.add_argument('--SC_Data', type=str, help='single cell reference data path', default=None)
    parser.add_argument('--cell_class_column', type=str, help='input cell class label column in scRef file', default='cell_type')
    parser.add_argument('--cell_num_column', type=str, help='cell number column in spatial file', default='cell_count')
    parser.add_argument('--hs_ST', action="store_true", help='high resolution ST data such as Slideseq, DBiT-seq, and HDST, MERFISH etc.')
    parser.add_argument("--VisiumCellsPlot", type=str2bool, const=True, default=True, nargs="?", help="whether to plot in VisiumCells mode or just scatter plot")
    args = parser.parse_args()

    args.tissue = 'SDRCTD_results'
    # Creates missing parent directories too and is a no-op when they exist.
    os.makedirs(os.path.join(args.out_dir, args.tissue), exist_ok=True)
    if args.RCTD_results_dir is None:
        # RCTD results are looked up in the output directory by default.
        args.RCTD_results_dir = args.out_dir

    sdr = SDRCTD(args.tissue, args.out_dir, args.RCTD_results_dir, args.RCTD_results_name, args.ST_Data, args.SC_Data, cell_class_column=args.cell_class_column, hs_ST=args.hs_ST)
    sdr.single_cell_type_assignment(cell_num_column=args.cell_num_column, VisiumCellsPlot=args.VisiumCellsPlot)
    if args.SC_Data is not None:
        sdr.cell_type_mean_assignment()
139 changes: 139 additions & 0 deletions compared_methods/SDRCTD_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os
import scanpy as sc
# import squidpy as sq
import numpy as np
import pandas as pd
import pathlib
import matplotlib.pyplot as plt
import matplotlib as mpl
# import skimage
import seaborn as sns
from itertools import chain
# from stardist.models import StarDist2D
from csbdeep.utils import normalize
from anndata import AnnData
from scipy.spatial.distance import pdist
import logging
import sys
from sklearn.metrics.pairwise import cosine_similarity

def PlotVisiumCells(adata, annotation_list, size=0.8, alpha_img=0.3, lw=1, subset=None, palette='tab20', show_circle=True, legend=True, ax=None, **kwargs):
    """Plot segmented cells coloured by an annotation, with optional spot outlines.

    Args:
        adata: AnnData whose ``uns['cell_locations']`` table has ``x``/``y``
            columns plus the annotation column; ``uns['spatial']`` and
            ``obsm['spatial']`` are read when ``show_circle`` is true.
        annotation_list: Column of ``cell_locations`` used for colouring.
        size, alpha_img, palette: Forwarded to ``sc.pl.spatial``.
        lw: Line width of the spot outlines.
        subset: If given, cells whose annotation is not in it are set to None.
        show_circle: Draw one grey circle per spot.
        legend: Keep (True) or remove (False) the legend.
        ax: Axes to draw on; a new figure/axes is created when omitted.
        **kwargs: Extra keyword arguments for ``sc.pl.spatial``.
    """
    if ax is None:
        # The drawing code below needs a concrete axes; with the former
        # behaviour a missing ``ax`` always ended in an AttributeError.
        _, ax = plt.subplots()

    merged_df = adata.uns['cell_locations'].copy()
    # Dummy AnnData with one observation per cell, carrying only what
    # sc.pl.spatial needs.
    test = sc.AnnData(np.zeros(merged_df.shape), obs=merged_df)
    test.obsm['spatial'] = merged_df[["x", "y"]].to_numpy().copy()
    # NOTE(review): uns is assigned by reference, so entries scanpy adds while
    # plotting presumably also appear in adata.uns -- confirm if that matters.
    test.uns = adata.uns

    if subset is not None:
        # test = test[test.obs[annotation_list].isin(subset)]
        test.obs.loc[~test.obs[annotation_list].isin(subset), annotation_list] = None

    sc.pl.spatial(
        test,
        color=annotation_list,
        size=size,
        frameon=False,
        alpha_img=alpha_img,
        show=False,
        palette=palette,
        na_in_legend=False,
        ax=ax, title='', sort_order=True, **kwargs
    )
    if show_circle:
        # Scale factors of the first library stored under uns['spatial'].
        library_id = list(adata.uns['spatial'].keys())[0]
        scalefactors = adata.uns['spatial'][library_id]['scalefactors']
        sf = scalefactors['tissue_hires_scalef']
        spot_radius = scalefactors['spot_diameter_fullres'] / 2
        for sloc in adata.obsm['spatial']:
            spot_outline = mpl.patches.Circle(
                (sloc[0] * sf, sloc[1] * sf),
                spot_radius * sf,
                ec="grey",
                lw=lw,
                fill=False
            )
            ax.add_patch(spot_outline)
    ax.axes.xaxis.label.set_visible(False)
    ax.axes.yaxis.label.set_visible(False)

    if not legend:
        current_legend = ax.get_legend()
        # get_legend() returns None when no legend was drawn.
        if current_legend is not None:
            current_legend.remove()

    # make frame visible
    for _, spine in ax.spines.items():
        spine.set_visible(True)



def assign_cell_type(cell_counts, cell_types):
    """Expand per-spot cell-type counts into a flat list of cell labels.

    For every row of ``cell_counts`` each cell type is repeated as often as
    its count says, the labels of that row are shuffled with
    ``np.random.shuffle``, and the rows are concatenated in order.

    Args:
        cell_counts: 2-D integer array, one row per spot and one column per
            cell type.
        cell_types: Sequence of cell-type labels, one per column.

    Returns:
        list: ``cell_counts.sum()`` labels, grouped by spot.
    """
    labels = np.asarray(cell_types)
    cell_type_list = []
    for cell_count in cell_counts:
        cell_count = np.asarray(cell_count)
        idx = np.flatnonzero(cell_count > 0)
        # np.repeat builds the row's labels in one step instead of nested
        # list multiplication followed by flattening.
        row_labels = np.repeat(labels[idx], cell_count[idx])
        np.random.shuffle(row_labels)
        # extend() appends in place; rebuilding the list with ``+`` on every
        # spot made the loop quadratic in the number of cells.
        cell_type_list.extend(row_labels)

    return cell_type_list

def distribute_cells(mat, cell_nums):  # mat: spots * cell_type; cell_nums: spots * 1
    """Split each spot's cell count over cell types according to its weights.

    Weights with an absolute value below 1e-3 are treated as zero.  The weights
    are scaled by the spot's cell count and their integer parts are allocated.
    While cells remain, only the largest leftover weights (as many as cells are
    still missing) are kept, renormalised, scaled and floored again.

    Args:
        mat: 2-D array of non-negative weights, one row per spot and one
            column per cell type.  It is not modified.
        cell_nums: 1-D array with the number of cells per spot (cast to int).

    Returns:
        numpy.ndarray: Integer array shaped like ``mat`` whose rows sum to
        ``cell_nums``.

    Raises:
        ValueError: If ``mat`` contains NaN/inf, or if a spot still needs cells
            but has no weight left to distribute them by.
        AssertionError: If ``cell_nums`` or ``mat`` contain negative values.
    """
    # Private copies: the thresholding below must not write into the caller's
    # array (it may back a DataFrame, or be read-only).
    mat = np.array(mat, dtype=float)
    cell_nums = np.asarray(cell_nums).astype(int)
    cell_nums_original = cell_nums.copy()

    if not np.all(np.isfinite(mat)):
        # NaN passes the ``< 0`` checks below and corrupts the integer casts.
        raise ValueError("mat must only contain finite weights")

    mat[np.absolute(mat) < 1e-3] = 0
    assert not np.any(cell_nums < 0)
    assert not np.any(mat < 0)

    cell_counts = np.zeros(mat.shape, dtype=int)
    mat = mat * cell_nums[:, None]

    while True:
        # Allocate the integer part; keep the fractional part and the number
        # of still unassigned cells for the next round.
        cell_num_dist = np.floor(mat).astype(int)
        cell_counts = cell_counts + cell_num_dist
        cell_nums = cell_nums - cell_num_dist.sum(1)
        mat = mat - cell_num_dist

        assert not np.any(cell_nums < 0)
        assert not np.any(mat < 0)

        if not np.any(cell_nums > 0):
            break

        starved = (cell_nums > 0) & ~np.any(mat > 0, axis=1)
        if np.any(starved):
            # Without this guard the loop would never terminate for such rows.
            raise ValueError(
                "cannot distribute remaining cells: spots "
                f"{np.flatnonzero(starved).tolist()} have no remaining weight"
            )

        # Double argsort = rank of each entry in descending order; keep only
        # as many top entries per row as cells are still missing.
        mat[mat.argsort()[:, ::-1].argsort() >= cell_nums[:, None]] = 0
        row_sums = mat.sum(1)[:, None]
        mat = np.divide(mat, row_sums, out=np.zeros_like(mat), where=row_sums != 0)
        mat = mat * cell_nums[:, None]

    assert np.array_equal(cell_counts.sum(1), cell_nums_original)

    return cell_counts




def configure_logging(logger_name):
    """Return the shared 'importer_logger' writing to a file and to stdout.

    Args:
        logger_name: Path prefix of the log file; ``'.log'`` is appended.

    Returns:
        logging.Logger: The logger named ``'importer_logger'`` at DEBUG level
        with exactly one file handler and one stdout handler.
    """
    LOG_LEVEL = logging.DEBUG
    log_filename = logger_name + '.log'
    importer_logger = logging.getLogger('importer_logger')
    importer_logger.setLevel(LOG_LEVEL)
    formatter = logging.Formatter('%(asctime)s : %(levelname)s : %(message)s')

    # The logger is a process-wide singleton: drop the handlers installed by
    # earlier calls, otherwise every message is emitted once per call and the
    # old log files stay open.
    for old_handler in list(importer_logger.handlers):
        importer_logger.removeHandler(old_handler)
        old_handler.close()

    fh = logging.FileHandler(filename=log_filename)
    fh.setLevel(LOG_LEVEL)
    fh.setFormatter(formatter)
    importer_logger.addHandler(fh)

    sh = logging.StreamHandler(sys.stdout)
    sh.setLevel(LOG_LEVEL)
    sh.setFormatter(formatter)
    importer_logger.addHandler(sh)
    return importer_logger



Loading

0 comments on commit 96b5b1a

Please sign in to comment.