# run-cellpose-qc.py
import argparse
import csv
import numpy as np
from collections import namedtuple
from pathlib import Path
from skimage import io
from skimage.segmentation import relabel_sequential
from scipy.optimize import linear_sum_assignment
# Parse input arguments
parser = argparse.ArgumentParser(description='Get Quality Metrics from Cellpose Run. Needs a folder with Raw Data, Ground Truth and Cellpose Labels')
parser.add_argument('dir', nargs=1, help='a directory containing the resulting images for QC (Raw, GT and Model Results)')
parser.add_argument('model', nargs=1, help='The name of the model being tested')
args = parser.parse_args()
data_folder = args.dir[0]
print(data_folder)
model_name = args.model[0]
print(model_name)
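# Example invocation (illustration only; the folder path and model name are
# placeholders, and the file-name pattern is assumed from the matching logic in
# compareLabels below):
#   python run-cellpose-qc.py /path/to/QC_images my_cellpose_model
# where /path/to/QC_images contains, for every raw image "img.tif", a ground
# truth "img_masks.tif" and a Cellpose prediction "img_cp_masks.tif".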
matching_criteria = dict()
## All functions to do the work
## Actual run is all the way down
def _raise(e):
    # small helper used by the checks below: raise the given exception from an `or`-style check
    raise e

def label_are_sequential(y):
    """ returns true if y has only sequential labels from 1... """
    labels = np.unique(y)
    return (set(labels) - {0}) == set(range(1, 1 + labels.max()))

def is_array_of_integers(y):
    return isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.integer)

def _check_label_array(y, name=None, check_sequential=False):
    err = ValueError("{label} must be an array of {integers}.".format(
        label = 'labels' if name is None else name,
        integers = ('sequential ' if check_sequential else '') + 'non-negative integers',
    ))
    is_array_of_integers(y) or print(err)
    if check_sequential:
        label_are_sequential(y) or print(err)
    else:
        y.min() >= 0 or print(err)
    return True
def label_overlap(x, y, check=True):
    if check:
        _check_label_array(x, 'x', True)
        _check_label_array(y, 'y', True)
        x.shape == y.shape or _raise(ValueError("x and y must have the same shape"))
    return _label_overlap(x, y)

def _label_overlap(x, y):
    x = x.ravel()
    y = y.ravel()
    overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint)
    for i in range(len(x)):
        overlap[x[i], y[i]] += 1
    return overlap
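# Illustrative example (not part of the original script): for flattened labels
# x = [0, 1, 1, 2] and y = [0, 1, 2, 2], _label_overlap returns
#     [[1, 0, 0],
#      [0, 1, 1],
#      [0, 0, 1]]
# i.e. entry (i, j) counts the pixels labelled i in x and j in y.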
def intersection_over_union(overlap):
    _check_label_array(overlap, 'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return overlap / (n_pixels_pred + n_pixels_true - overlap)

matching_criteria['iou'] = intersection_over_union

def intersection_over_true(overlap):
    _check_label_array(overlap, 'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return overlap / n_pixels_true

matching_criteria['iot'] = intersection_over_true

def intersection_over_pred(overlap):
    _check_label_array(overlap, 'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    return overlap / n_pixels_pred

matching_criteria['iop'] = intersection_over_pred
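# Worked example (assumed values, for illustration only): if a GT object of 75 px
# and a predicted object of 100 px share 50 px, then for that pair
#   iou = 50 / (100 + 75 - 50) = 0.4
#   iot = 50 / 75  ~ 0.667
#   iop = 50 / 100 = 0.5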
def precision(tp, fp, fn):
    return tp / (tp + fp) if tp > 0 else 0

def recall(tp, fp, fn):
    return tp / (tp + fn) if tp > 0 else 0

def accuracy(tp, fp, fn):
    # also known as "average precision" (?)
    # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation
    return tp / (tp + fp + fn) if tp > 0 else 0

def f1(tp, fp, fn):
    # also known as "dice coefficient"
    return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0

def _safe_divide(x, y):
    return x / y if y > 0 else 0.0
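# Worked example (assumed counts, for illustration only): with tp=8, fp=2, fn=4
#   precision = 8 / 10  = 0.8
#   recall    = 8 / 12  ~ 0.667
#   accuracy  = 8 / 14  ~ 0.571
#   f1        = 16 / 22 ~ 0.727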
def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
    """Calculate detection/instance segmentation metrics between ground truth and predicted label images.

    Currently, the following metrics are implemented:
    'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'

    Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)
    if their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)

    * mean_matched_score is the mean IoU of matched true positives
    * mean_true_score is the mean IoU of matched true positives, normalized by the total number of GT objects
    * panoptic_quality is defined as in Eq. 1 of Kirillov et al. "Panoptic Segmentation", CVPR 2019

    Parameters
    ----------
    y_true: ndarray
        ground truth label image (integer valued)
    y_pred: ndarray
        predicted label image (integer valued)
    thresh: float
        threshold for matching criterion (default 0.5)
    criterion: string
        matching criterion (default IoU)
    report_matches: bool
        if True, additionally calculate matched_pairs and matched_scores (note that this returns even gt-pred pairs whose scores are below 'thresh')

    Returns
    -------
    Matching object with different metrics as attributes

    Examples
    --------
    >>> y_true = np.zeros((100,100), np.uint16)
    >>> y_true[10:20,10:20] = 1
    >>> y_pred = np.roll(y_true,5,axis = 0)
    >>> stats = matching(y_true, y_pred)
    >>> print(stats)
    Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)
    """
    _check_label_array(y_true, 'y_true')
    _check_label_array(y_pred, 'y_pred')
    y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred)))
    criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion))
    if thresh is None: thresh = 0
    thresh = float(thresh) if np.isscalar(thresh) else map(float, thresh)

    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:, 1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2 * n_matched)
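            # Note (added explanation, not in the original): every GT/pred pair whose
            # score clears the threshold contributes a cost of -1, so the assignment
            # first maximizes the number of above-threshold matches; the extra
            # -scores/(2*n_matched) term sums to less than 1 over any assignment and
            # therefore only breaks ties in favour of higher-scoring pairs.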
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind, pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp
        # assert tp+fp == n_pred
        # assert tp+fn == n_true

        # the score sum over all matched objects (tp)
        sum_matched_score = np.sum(scores[true_ind, pred_ind][match_ok]) if not_trivial else 0.0
        # the score average over all matched objects (tp)
        mean_matched_score = _safe_divide(sum_matched_score, tp)
        # the score average over all gt/true objects
        mean_true_score = _safe_divide(sum_matched_score, n_true)
        panoptic_quality = _safe_divide(sum_matched_score, tp + fp/2 + fn/2)

        stats_dict = dict(
            criterion = criterion,
            thresh = thr,
            fp = fp,
            tp = tp,
            fn = fn,
            precision = precision(tp, fp, fn),
            recall = recall(tp, fp, fn),
            accuracy = accuracy(tp, fp, fn),
            f1 = f1(tp, fp, fn),
            n_true = n_true,
            n_pred = n_pred,
            mean_true_score = mean_true_score,
            mean_matched_score = mean_matched_score,
            panoptic_quality = panoptic_quality,
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update(
                    # int() to be json serializable
                    matched_pairs = tuple((int(map_rev_true[i]), int(map_rev_pred[j])) for i, j in zip(1 + true_ind, 1 + pred_ind)),
                    matched_scores = tuple(scores[true_ind, pred_ind]),
                    matched_tps = tuple(map(int, np.flatnonzero(match_ok))),
                )
            else:
                stats_dict.update(
                    matched_pairs = (),
                    matched_scores = (),
                    matched_tps = (),
                )
        return namedtuple('Matching', stats_dict.keys())(*stats_dict.values())

    return _single(thresh) if np.isscalar(thresh) else tuple(map(_single, thresh))
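# Illustrative usage (not part of the original script; gt_labels and pred_labels
# are placeholder label arrays): passing a tuple of thresholds returns a tuple
# with one Matching result per threshold, e.g.
#   stats_per_thresh = matching(gt_labels, pred_labels, thresh=(0.25, 0.5, 0.75))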
def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    """matching metrics for list of images, see `stardist.matching.matching`
    """
    len(y_true) == len(y_pred) or _raise(ValueError("y_true and y_pred must have the same length."))
    return matching_dataset_lazy(
        tuple(zip(y_true, y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,
    )
def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    # tqdm is only needed for the progress bar here, so it is imported lazily
    # (the main QC run below does not call this function)
    from tqdm import tqdm

    expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))

    single_thresh = False
    if np.isscalar(thresh):
        single_thresh = True
        thresh = (thresh,)

    tqdm_kwargs = {}
    tqdm_kwargs['disable'] = not bool(show_progress)
    if int(show_progress) > 1:
        tqdm_kwargs['total'] = int(show_progress)

    # compute matching stats for every pair of label images
    if parallel:
        from concurrent.futures import ThreadPoolExecutor
        fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)
        with ThreadPoolExecutor() as pool:
            stats_all = tuple(pool.map(fn, tqdm(y_gen, **tqdm_kwargs)))
    else:
        stats_all = tuple(
            matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)
            for y_t, y_p in tqdm(y_gen, **tqdm_kwargs)
        )

    # accumulate results over all images for each threshold separately
    n_images, n_threshs = len(stats_all), len(thresh)
    accumulate = [{} for _ in range(n_threshs)]
    for stats in stats_all:
        for i, s in enumerate(stats):
            acc = accumulate[i]
            for k, v in s._asdict().items():
                if k == 'mean_true_score' and not bool(by_image):
                    # convert mean_true_score to "sum_matched_score"
                    acc[k] = acc.setdefault(k, 0) + v * s.n_true
                else:
                    try:
                        acc[k] = acc.setdefault(k, 0) + v
                    except TypeError:
                        pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    for thr, acc in zip(thresh, accumulate):
        set(acc.keys()) == expected_keys or _raise(ValueError("unexpected keys"))
        acc['criterion'] = criterion
        acc['thresh'] = thr
        acc['by_image'] = bool(by_image)
        if bool(by_image):
            for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
                acc[k] /= n_images
        else:
            tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']
            sum_matched_score = acc['mean_true_score']

            mean_matched_score = _safe_divide(sum_matched_score, tp)
            mean_true_score = _safe_divide(sum_matched_score, n_true)
            panoptic_quality = _safe_divide(sum_matched_score, tp + fp/2 + fn/2)

            acc.update(
                precision = precision(tp, fp, fn),
                recall = recall(tp, fp, fn),
                accuracy = accuracy(tp, fp, fn),
                f1 = f1(tp, fp, fn),
                mean_true_score = mean_true_score,
                mean_matched_score = mean_matched_score,
                panoptic_quality = panoptic_quality,
            )

    accumulate = tuple(namedtuple('DatasetMatching', acc.keys())(*acc.values()) for acc in accumulate)
    return accumulate[0] if single_thresh else accumulate
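# Illustrative usage (not part of the original script; the two lists are
# placeholder sequences of GT and predicted label images):
#   dataset_stats = matching_dataset(list_of_gt_labels, list_of_pred_labels, thresh=0.5)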
# Here we start testing the differences between GT and predicted label images
def compareLabels(model_name, image_folder):
    image_folder = Path(image_folder)
    # Grab all tif images, skipping the mask and flow outputs
    raw_images = [x for x in Path(image_folder).glob("*.tif") if not ("masks" in x.name or "flows" in x.name)]

    # Define the save folder inside the image folder
    results_path = image_folder / "QC-Results"
    # Make the directory if it's missing
    results_path.absolute().mkdir(exist_ok=True)

    with open(results_path / ("Quality_Control for " + model_name + ".csv"), "w", newline='') as file:
        writer = csv.writer(file, delimiter=",")
        writer.writerow(["model", "image", "Prediction v. GT Intersection over Union", "false positive", "true positive", "false negative", "precision", "recall", "accuracy", "f1 score", "n_true", "n_pred", "mean_true_score", "mean_matched_score", "panoptic_quality"])

        # define the images
        for raw in raw_images:
            mask_name = f"{raw.stem}_masks{raw.suffix}"
            gt_image = raw.parent / mask_name

            cp_mask_name = f"{raw.stem}_cp_masks{raw.suffix}"
            pred_image = raw.parent / cp_mask_name

            if not gt_image.is_file():
                raise FileNotFoundError(gt_image)
            if not pred_image.is_file():
                raise FileNotFoundError(pred_image)

            # test_input = io.imread(raw)
            test_prediction = io.imread(pred_image)
            test_ground_truth_image = io.imread(gt_image)

            # Calculate the matching (with IoU threshold `thresh`) and all metrics
            stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)

            # Convert pixel values to 0 or 255 (work on copies so the label images stay intact)
            test_prediction_0_to_255 = test_prediction.copy()
            test_prediction_0_to_255[test_prediction_0_to_255 > 0] = 255

            test_ground_truth_0_to_255 = test_ground_truth_image.copy()
            test_ground_truth_0_to_255[test_ground_truth_0_to_255 > 0] = 255

            # Intersection over Union metric on the binarized masks
            intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)
            union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)
            iou_score = np.sum(intersection) / np.sum(union)

            writer.writerow([model_name, raw.name, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])

# Finally running the check on the given inputs
compareLabels(model_name, data_folder)