-
Notifications
You must be signed in to change notification settings - Fork 1
/
NEST_Visium_HD_segmentation.py
440 lines (334 loc) · 17.8 KB
/
NEST_Visium_HD_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
# adapted from the Visium
print('Package loading ...')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata
import geopandas as gpd
import scanpy as sc
from tifffile import imread, imwrite
from csbdeep.utils import normalize
from stardist.models import StarDist2D
from shapely.geometry import Polygon, Point
from scipy import sparse
from matplotlib.colors import ListedColormap
import argparse
import os
from collections import defaultdict
from shapely import MultiPoint, centroid
import gzip
import pickle
import altair as alt
import altairThemes
alt.themes.register("publishTheme", altairThemes.publishTheme)
# enable the newly registered theme
alt.themes.enable("publishTheme")
#matplotlib inline
#config InlineBackend.figure_format = 'retina'
# General image plotting functions
def plot_mask_and_save_image(title, gdf, img, cmap, output_name=None, bbox=None):
    """Plot the (optionally cropped) tissue image beside the nuclei polygons.

    Parameters
    ----------
    title : str
        Title for the image panel.
    gdf : geopandas.GeoDataFrame
        Nuclei polygons; must have a 'geometry' column.
    img : ndarray
        Full-resolution tissue image.
    cmap : matplotlib colormap
        Colormap used to draw the polygons.
    output_name : str, optional
        If given, the figure is saved to this path; otherwise it is shown.
    bbox : tuple, optional
        (xmin, ymin, xmax, ymax) crop box; only polygons intersecting it are drawn.
    """
    if bbox is not None:
        # Crop the image to the bounding box (rows are y, columns are x)
        cropped_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
    else:
        cropped_img = img
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    # Plot the cropped image
    axes[0].imshow(cropped_img, cmap='gray', origin='lower')
    axes[0].set_title(title)
    axes[0].axis('off')
    if bbox is not None:
        # Keep only the polygons that intersect the bounding box
        bbox_polygon = Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])])
        intersects_bbox = gdf['geometry'].intersects(bbox_polygon)
        filtered_gdf = gdf[intersects_bbox]
    else:
        filtered_gdf = gdf
    # Plot the filtered polygons on the second axis.
    # NOTE: the previous axes[1].legend(...) call was removed — there are no
    # labeled artists on this axis, so it only emitted a matplotlib warning.
    filtered_gdf.plot(cmap=cmap, ax=axes[1])
    axes[1].axis('off')
    # Save the plot if output_name is provided
    if output_name is not None:
        plt.savefig(output_name, bbox_inches='tight')  # bbox_inches='tight' keeps everything in frame
        plt.close(fig)  # free the figure so repeated calls do not accumulate open figures
    else:
        plt.show()
def plot_gene_and_save_image(title, gdf, gene, img, adata, bbox=None, output_name=None):
    """Plot the (optionally cropped) image beside nuclei colored by one gene's expression.

    Parameters
    ----------
    title : str
        Title for the image panel.
    gdf : geopandas.GeoDataFrame
        Nuclei polygons with an 'id' column matching ``adata`` observation ids.
    gene : str
        Gene (var) name to color the polygons by.
    img : ndarray
        Full-resolution tissue image.
    adata : anndata.AnnData
        Expression matrix whose obs index aligns with ``gdf['id']``.
    bbox : tuple, optional
        (xmin, ymin, xmax, ymax) crop box; only polygons intersecting it are drawn.
    output_name : str, optional
        If given, the figure is saved to this path; otherwise it is shown.
    """
    if bbox is not None:
        # Crop the image to the bounding box
        cropped_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
    else:
        cropped_img = img
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    # Plot the cropped image
    axes[0].imshow(cropped_img, cmap='gray', origin='lower')
    axes[0].set_title(title)
    axes[0].axis('off')
    # Create filtering polygon
    if bbox is not None:
        bbox_polygon = Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])])
    # Pull the gene of interest out of adata and merge it with the geodataframe on 'id'
    gene_expression = adata[:, gene].to_df()
    gene_expression['id'] = gene_expression.index
    merged_gdf = gdf.merge(gene_expression, left_on='id', right_on='id')
    if bbox is not None:
        # Filter for polygons in the box
        intersects_bbox = merged_gdf['geometry'].intersects(bbox_polygon)
        filtered_gdf = merged_gdf[intersects_bbox]
    else:
        filtered_gdf = merged_gdf
    # Plot the filtered polygons on the second axis; legend=True draws the colorbar.
    # NOTE: the previous axes[1].legend(...) call was removed — there are no
    # labeled artists on this axis, so it only emitted a matplotlib warning.
    filtered_gdf.plot(column=gene, cmap='inferno', legend=True, ax=axes[1])
    axes[1].set_title(gene)
    axes[1].axis('off')
    # Save the plot if output_name is provided
    if output_name is not None:
        plt.savefig(output_name, bbox_inches='tight')  # bbox_inches='tight' keeps the colorbar in frame
        plt.close(fig)  # free the figure so repeated calls do not accumulate open figures
    else:
        plt.show()
def plot_clusters_and_save_image(title, gdf, img, adata, bbox=None, color_by_obs=None, output_name=None, color_list=None):
    """Plot the (optionally cropped) image beside nuclei colored by a categorical obs column.

    Bug fix: a caller-supplied ``color_list`` was previously overwritten
    unconditionally by a hard-coded palette, which also made the tab20
    fallback branch unreachable. The hard-coded palette is now only the
    default used when ``color_list`` is None.

    Parameters
    ----------
    title : str
        Title for the image panel.
    gdf : geopandas.GeoDataFrame
        Nuclei polygons with an 'id' column matching ``adata.obs`` index.
    img : ndarray
        Full-resolution tissue image.
    adata : anndata.AnnData
        Object whose ``obs[color_by_obs]`` holds the category per nucleus.
    bbox : tuple, optional
        (xmin, ymin, xmax, ymax) crop box; only polygons intersecting it are drawn.
    color_by_obs : str
        Name of the categorical column in ``adata.obs`` to color by.
    output_name : str, optional
        If given, the figure is saved to this path; otherwise it is shown.
    color_list : list of str, optional
        Hex colors, one per category; falls back to the default palette below.
    """
    if color_list is None:
        # Default 30-color palette used when the caller does not provide one
        color_list=["#7f0000","#808000","#483d8b","#008000","#bc8f8f","#008b8b","#4682b4","#000080","#d2691e","#9acd32","#8fbc8f","#800080","#b03060","#ff4500","#ffa500","#ffff00","#00ff00","#8a2be2","#00ff7f","#dc143c","#00ffff","#0000ff","#ff00ff","#1e90ff","#f0e68c","#90ee90","#add8e6","#ff1493","#7b68ee","#ee82ee"]
    if bbox is not None:
        cropped_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
    else:
        cropped_img = img
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(cropped_img, cmap='gray', origin='lower')
    axes[0].set_title(title)
    axes[0].axis('off')
    if bbox is not None:
        bbox_polygon = Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])])
    # Build a discrete colormap with one color per category
    unique_values = adata.obs[color_by_obs].astype('category').cat.categories
    num_categories = len(unique_values)
    if color_list is not None and len(color_list) >= num_categories:
        custom_cmap = ListedColormap(color_list[:num_categories], name='custom_cmap')
    else:
        # Use default tab20 colors if color_list is insufficient
        tab20_colors = plt.cm.tab20.colors[:num_categories]
        custom_cmap = ListedColormap(tab20_colors, name='custom_tab20_cmap')
    # Attach the category labels to the polygons via the shared 'id' key
    merged_gdf = gdf.merge(adata.obs[color_by_obs].astype('category'), left_on='id', right_index=True)
    if bbox is not None:
        intersects_bbox = merged_gdf['geometry'].intersects(bbox_polygon)
        filtered_gdf = merged_gdf[intersects_bbox]
    else:
        filtered_gdf = merged_gdf
    # Plot the filtered polygons on the second axis
    plot = filtered_gdf.plot(column=color_by_obs, cmap=custom_cmap, ax=axes[1], legend=True)
    axes[1].set_title(color_by_obs)
    axes[1].axis('off')
    # Move the legend outside the plot area; (1.25, 1) is the value the original
    # code effectively used (its second set_bbox_to_anchor call overrode the first)
    plot.get_legend().set_bbox_to_anchor((1.25, 1))
    if output_name is not None:
        plt.savefig(output_name, bbox_inches='tight')
        plt.close(fig)  # free the figure so repeated calls do not accumulate open figures
    else:
        plt.show()
# Plotting function for nuclei area distribution
def plot_nuclei_area(gdf, area_cut_off, output_name):
    """Plot nuclei-area histograms before and after applying ``area_cut_off``.

    Intended to help pick the ``--mask_area`` parameter: the left panel shows
    the raw distribution of ``gdf['area']``, the right panel only the nuclei
    with area strictly below the cutoff.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame (or DataFrame)
        Must contain an 'area' column.
    area_cut_off : float
        Upper area bound used in the filtered histogram.
    output_name : str or None
        If not None, the figure is saved to this path; otherwise it is shown.
    """
    fig, axs = plt.subplots(1, 2, figsize=(15, 4))
    # Plot the histograms
    axs[0].hist(gdf['area'], bins=50, edgecolor='black')
    axs[0].set_title('Nuclei Area')
    axs[1].hist(gdf[gdf['area'] < area_cut_off]['area'], bins=50, edgecolor='black')
    axs[1].set_title('Nuclei Area Filtered:'+str(area_cut_off))
    plt.tight_layout()
    # Save the plot if output_name is provided
    if output_name is not None:
        plt.savefig(output_name, bbox_inches='tight')
        plt.close(fig)  # free the figure so repeated calls do not accumulate open figures
    else:
        plt.show()
# Total UMI distribution plotting function
def total_umi(adata_, cut_off, output_name):
    """Plot total-UMI boxplots before and after applying ``cut_off``.

    Intended to help pick the ``--mask_count`` parameter: the left panel shows
    the distribution of ``adata_.obs['total_counts']``, the right panel only
    the observations whose total counts exceed the cutoff.

    Parameters
    ----------
    adata_ : anndata.AnnData
        Must have 'total_counts' in ``.obs`` (e.g. from calculate_qc_metrics).
    cut_off : float
        Lower bound used in the filtered boxplot.
    output_name : str or None
        If not None, the figure is saved to this path; otherwise it is shown.
    """
    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    # Box plot
    axs[0].boxplot(adata_.obs["total_counts"], vert=False, widths=0.7, patch_artist=True, boxprops=dict(facecolor='skyblue'))
    axs[0].set_title('Total Counts')
    # Box plot after filtering
    axs[1].boxplot(adata_.obs["total_counts"][adata_.obs["total_counts"] > cut_off], vert=False, widths=0.7, patch_artist=True, boxprops=dict(facecolor='skyblue'))
    axs[1].set_title('Total Counts > ' + str(cut_off))
    # Remove y-axis ticks and labels (the single box needs no categorical axis)
    for ax in axs:
        ax.get_yaxis().set_visible(False)
    plt.tight_layout()
    # Save the plot if output_name is provided
    if output_name is not None:
        plt.savefig(output_name, bbox_inches='tight')
        plt.close(fig)  # free the figure so repeated calls do not accumulate open figures
    else:
        plt.show()
# ---------------- CLI arguments, image loading, and nuclei segmentation ----------------
parser = argparse.ArgumentParser()
################## Mandatory ####################################################################
#parser.add_argument( '--data_name', type=str, help='Name of the dataset', default="Visium_HD_Human_Colon_Cancer_square_002um_outputs_full")
parser.add_argument( '--data_from', type=str, default='/cluster/projects/schwartzgroup/fatema/data/Visium_HD_Human_Colon_Cancer_square_002um_outputs/', help='Path to the dataset to read from. Space Ranger outs/ folder is preferred. Otherwise, provide the *.mtx file of the gene expression matrix.')
parser.add_argument( '--file_name', type=str, help='Name of the btf file', default='Visium_HD_Human_Colon_Cancer_tissue_image.btf')
parser.add_argument( '--data_to', type=str, default='/cluster/projects/schwartzgroup/fatema/NEST/data/segmented_visium_hd_CRC/', help='Path to save the input graph (to be passed to GAT)')
parser.add_argument( '--prob_thresh', type=float, default=0.75, help='Probability threshold for the StarDist segmentation algorithm')
parser.add_argument( '--mask_area', type=int, default=1000, help='Parameter for post segmentation filtering: Nuclei area size')
parser.add_argument( '--mask_count', type=int, default=100, help='Parameter for post segmentation filtering: UMI count')
#prob_thresh=0.75
args = parser.parse_args()
# Make sure the output directory exists before anything is written into it
if not os.path.exists(args.data_to):
    os.makedirs(args.data_to)
# Load the image file
# Change dir_base as needed to the directory where the downloaded example data is stored
dir_base = args.data_from
filename = args.file_name
print('Img loading ...')
img = imread(dir_base + filename)
# Load the pretrained model (H&E-trained StarDist weights)
print('StarDist loading ...')
model = StarDist2D.from_pretrained('2D_versatile_he')
# Percentile normalization of the image
# Adjust min_percentile and max_percentile as needed
min_percentile = 5
max_percentile = 95
print('Normalizing the Img')
img = normalize(img, min_percentile, max_percentile)
# Predict cell nuclei using the normalized image
# Adjust nms_thresh and prob_thresh as needed
# prob_thresh=0.75
print('Segmentation running ...with prob_thresh=%g'%args.prob_thresh)
# predict_instances_big processes the whole-slide image in 4096-px blocks (with
# overlap/context to stitch nuclei at block borders) so it fits in memory;
# returns a label image plus a dict with the polygon outlines ('coord')
labels, polys = model.predict_instances_big(img, axes='YXC', block_size=4096, prob_thresh=args.prob_thresh, nms_thresh=0.001, min_overlap=128, context=128, normalizer=None, n_tiles=(4,4,1))
# ---------------- Build a GeoDataFrame of nuclei polygons and load the HD data ----------------
# Creating a list to store Polygon geometries
geometries = []
# Iterating through each nuclei in the 'polys' dict returned by StarDist
for nuclei in range(len(polys['coord'])):
    # Extracting coordinates for the current nuclei and converting them to (y, x) format
    coords = [(y, x) for x, y in zip(polys['coord'][nuclei][0], polys['coord'][nuclei][1])]
    # Creating a Polygon geometry from the coordinates
    geometries.append(Polygon(coords))
# Creating a GeoDataFrame using the Polygon geometries
gdf = gpd.GeoDataFrame(geometry=geometries)
# Stable string ids ("ID_1", "ID_2", ...) used later to join nuclei with barcodes
gdf['id'] = [f"ID_{i+1}" for i, _ in enumerate(gdf.index)]
# Define a single color cmap (for mask plotting helpers above)
cmap=ListedColormap(['grey'])
# Load Visium HD data
print('filtered_feature_bc_matrix.h5 reading')
raw_h5_file = dir_base+'filtered_feature_bc_matrix.h5'
###### save the barcode index and coordinates for later use ##############################################
adata = sc.read_visium(path=dir_base, count_file='filtered_feature_bc_matrix.h5')
barcode_list = list(adata.obs.index)
# barcode -> [x, y] full-resolution pixel coordinates, used at the end to
# compute a centroid per segmented cell
barcode_coord = dict()
for i in range(0, len(barcode_list)):
    barcode_coord[barcode_list[i]] = [adata.obsm['spatial'][i][0], adata.obsm['spatial'][i][1]]
#############################################################################################################
# Re-read only the count matrix (read_10x_h5 does not carry the spatial metadata loaded above)
adata = sc.read_10x_h5(raw_h5_file)
# Load the Spatial Coordinates
print('Load the spatial coordinates parquet file')
tissue_position_file = dir_base+'spatial/'+'tissue_positions.parquet'
df_tissue_positions=pd.read_parquet(tissue_position_file)
#Set the index of the dataframe to the barcodes
df_tissue_positions = df_tissue_positions.set_index('barcode')
# Create an index in the dataframe to check joins
df_tissue_positions['index']=df_tissue_positions.index
# Adding the tissue positions to the meta data
adata.obs = pd.merge(adata.obs, df_tissue_positions, left_index=True, right_index=True)
# Create a GeoDataFrame from the DataFrame of coordinates (one Point per barcode)
geometry = [Point(xy) for xy in zip(df_tissue_positions['pxl_col_in_fullres'], df_tissue_positions['pxl_row_in_fullres'])]
gdf_coordinates = gpd.GeoDataFrame(df_tissue_positions, geometry=geometry)
# Perform a spatial join to check which barcode coordinates fall within a cell nucleus polygon
result_spatial_join = gpd.sjoin(gdf_coordinates, gdf, how='left', predicate='within')
# ---------------- Assign barcodes to nuclei and sum counts per nucleus ----------------
# Identify nuclei associated barcodes and find barcodes that are in more than one nucleus
result_spatial_join['is_within_polygon'] = ~result_spatial_join['index_right'].isna()
# A barcode duplicated in the join landed inside overlapping nuclei polygons
barcodes_in_overlaping_polygons = pd.unique(result_spatial_join[result_spatial_join.duplicated(subset=['index'])]['index'])
result_spatial_join['is_not_in_an_polygon_overlap'] = ~result_spatial_join['index'].isin(barcodes_in_overlaping_polygons)
# Remove barcodes in overlapping nuclei: keep only barcodes inside exactly one polygon
barcodes_in_one_polygon = result_spatial_join[result_spatial_join['is_within_polygon'] & result_spatial_join['is_not_in_an_polygon_overlap']]
# The AnnData object is filtered to only contain the barcodes that are in non-overlapping polygon regions
filtered_obs_mask = adata.obs_names.isin(barcodes_in_one_polygon['index'])
filtered_adata = adata[filtered_obs_mask,:]
# Add the results of the point spatial join (polygon 'id' etc.) to the Anndata object
filtered_adata.obs = pd.merge(filtered_adata.obs, barcodes_in_one_polygon[['index','geometry','id','is_within_polygon','is_not_in_an_polygon_overlap']], left_index=True, right_index=True)
# Group the data by unique nucleus IDs
groupby_object = filtered_adata.obs.groupby(['id'], observed=True)
# Extract the gene expression counts from the AnnData object
counts = filtered_adata.X
# Obtain the number of unique nuclei and the number of genes in the expression data
N_groups = groupby_object.ngroups
N_genes = counts.shape[1]
# Initialize a sparse matrix to store the summed gene counts for each nucleus
# (lil_matrix supports efficient row-wise assignment while filling)
summed_counts = sparse.lil_matrix((N_groups, N_genes))
# Lists to store the IDs of polygons and the current row index
polygon_id = []
row = 0
# Iterate over each unique polygon and sum the gene counts of its 2um barcodes
for polygons, idx_ in groupby_object.indices.items():
    summed_counts[row] = counts[idx_].sum(0)
    row += 1
    polygon_id.append(polygons)
# Create an AnnData object from the summed count matrix (csr for efficient downstream use)
summed_counts = summed_counts.tocsr()
grouped_filtered_adata = anndata.AnnData(X=summed_counts,obs=pd.DataFrame(polygon_id,columns=['id'],index=polygon_id),var=filtered_adata.var)
# Store the area of each nucleus in the GeoDataframe (used for area-based filtering below)
gdf['area'] = gdf['geometry'].area
# Calculate quality control metrics (adds 'total_counts' etc. to .obs)
sc.pp.calculate_qc_metrics(grouped_filtered_adata, inplace=True)
########### IMP for parameter setting for mask_area and mask_count #############
# Plot the nuclei area distribution before and after filtering
# plot_nuclei_area(gdf=gdf,area_cut_off=1000,output_name=dir_base+"image_nuclei_area.tif")
# Plot total UMI distribution
# total_umi(grouped_filtered_adata, 100,output_name=dir_base+"image_umi.tif")
print('Filtering based on predetermined mask_area and mask_count')
# Create a mask based on the 'id' column for nuclei with 'area' less than args.mask_area
mask_area = grouped_filtered_adata.obs['id'].isin(gdf[gdf['area'] < args.mask_area].id)
# Create a mask based on the 'total_counts' column for values greater than args.mask_count
mask_count = grouped_filtered_adata.obs['total_counts'] > args.mask_count
# Apply both masks to the original AnnData to create a new filtered AnnData object
count_area_filtered_adata = grouped_filtered_adata[mask_area & mask_count, :]
# Recalculate quality control metrics for the filtered AnnData object
sc.pp.calculate_qc_metrics(count_area_filtered_adata, inplace=True)
# Persist the per-nucleus expression matrix; '_p75' refers to the default prob_thresh=0.75
print('Writing: '+args.data_to+'count_area_filtered_adata_p75.h5ad')
filename=args.data_to+'count_area_filtered_adata_p75.h5ad'
count_area_filtered_adata.write_h5ad(filename=filename, compression='gzip')
print('Write done')
################ now retrieve the coordinates by intersecting the original anndata with the segmented one ######################
cell_id = np.array(count_area_filtered_adata.obs.index)
print('Retrieving the coordinates for the segmented cells ...')
barcode_vs_id = pd.DataFrame(filtered_adata.obs['id'])
# following give barcode and associated coordinates in adata.obs.index and adata.obsm['spatial'] respectively
# barcode_list, barcode_coord
# following gives barcode vs id for segmented+grouped data p75:
# barcode_vs_id
# combine above two to get: id = list of (barcodes, coordinates) assigned to that id
id_barcode_coord = defaultdict(list) # key=id, value=[[barcode, [coord]]]
for i in range(0, len(barcode_vs_id)):
    id_barcode_coord[barcode_vs_id.id[i]].append([barcode_vs_id.index[i], barcode_coord[barcode_vs_id.index[i]]])
# filter it to keep only those who are in the final area+UMI filtered data
id_barcode_coord_temp = defaultdict(list) # key=id, value=[[barcode, [coord]]]
for i in range(0, len(cell_id)):
    id_barcode_coord_temp[cell_id[i]] = id_barcode_coord[cell_id[i]]
id_barcode_coord = id_barcode_coord_temp
# intersect barcode_id_coord with adata_h5.obs['id'] --> to get coordinates of cells in adata_h5
# Each segmented cell's coordinate is the centroid of the 2um barcodes assigned to it
coordinates = np.zeros((cell_id.shape[0], 2)) # insert the coordinates in the order of cell_barcodes
cell_barcode = [] # list of raw barcodes for each polygon
for i in range (0, cell_id.shape[0]):
    list_barcodes_coord = id_barcode_coord[cell_id[i]]
    cell_barcode.append([])
    list_coords = []
    for j in range (0, len(list_barcodes_coord)):
        cell_barcode[i].append(list_barcodes_coord[j][0])
        list_coords.append((list_barcodes_coord[j][1]))
    # shapely MultiPoint centroid = mean position of the member barcodes
    point = MultiPoint(list_coords)
    coordinates[i,0] = point.centroid.coords[0][0]
    coordinates[i,1] = point.centroid.coords[0][1]
# Serialize [coordinates, cell_barcode] for the downstream NEST graph-building step
with gzip.open(args.data_to + 'coordinate_barcode', 'wb') as fp:
    pickle.dump([coordinates, cell_barcode], fp)
print('Coordinate generation done: '+args.data_to + 'coordinate_barcode')
############################ Now plot it to see how does it look ###################
# Quick visual sanity check: scatter plot of the derived cell-centroid coordinates
data_list_pd = pd.DataFrame({
    'X': [coordinates[idx, 0] for idx in range(coordinates.shape[0])],
    'Y': [coordinates[idx, 1] for idx in range(coordinates.shape[0])],
})
chart = alt.Chart(data_list_pd).mark_point(filled=True, opacity=1).encode(
    alt.X('X', scale=alt.Scale(zero=False)),
    alt.Y('Y', scale=alt.Scale(zero=False)),
)
chart.save(args.data_to + 'tissue_altair_plot.html')
print('Altair plot generation done: '+args.data_to + 'tissue_altair_plot.html')