diff --git a/docs/source/user_guide/evaluation.rst b/docs/source/user_guide/evaluation.rst index 57e25d5b11..6edc1d38c5 100644 --- a/docs/source/user_guide/evaluation.rst +++ b/docs/source/user_guide/evaluation.rst @@ -729,8 +729,9 @@ The only difference between each task type is in how the IoU between objects is calculated: - For object detections, IoUs are computed between each pair of bounding boxes -- For instance segmentations and polygons, IoUs are computed between the - polygonal shapes rather than their rectangular bounding boxes +- For instance segmentations, when ``use_masks=True``, IoUs are computed + between the dense pixel masks rather than their rectangular bounding boxes +- For polygons, IoUs are computed between the polygonal shapes - For keypoint tasks, `object keypoint similarity `_ is computed for each pair of objects, using the extent of the ground truth @@ -744,8 +745,7 @@ stored in |Detections| format. For instance segmentation tasks, the ground truth and predicted objects should be stored in |Detections| format, and each |Detection| instance should have its -:attr:`mask ` attribute populated to -define the extent of the object within its bounding box. +mask populated to define the extent of the object within its bounding box. .. note:: diff --git a/fiftyone/utils/eval/coco.py b/fiftyone/utils/eval/coco.py index 1efa89331b..a82445d59a 100644 --- a/fiftyone/utils/eval/coco.py +++ b/fiftyone/utils/eval/coco.py @@ -48,7 +48,8 @@ class COCOEvaluationConfig(DetectionEvaluationConfig): of the provided :class:`fiftyone.core.labels.Polyline` instances rather than using their actual geometries tolerance (None): a tolerance, in pixels, when generating approximate - polylines for instance masks. Typical values are 1-3 pixels + polylines for instance masks. Typical values are 1-3 pixels. By + default, IoUs are computed directly on the dense pixel masks compute_mAP (False): whether to perform the necessary computations so that mAP and PR curves can be generated iou_threshs (None): a list of IoU thresholds to use when computing mAP diff --git a/fiftyone/utils/eval/openimages.py b/fiftyone/utils/eval/openimages.py index ffd4d65e20..1d14aafdaa 100644 --- a/fiftyone/utils/eval/openimages.py +++ b/fiftyone/utils/eval/openimages.py @@ -43,7 +43,8 @@ class OpenImagesEvaluationConfig(DetectionEvaluationConfig): of the provided :class:`fiftyone.core.labels.Polyline` instances rather than using their actual geometries tolerance (None): a tolerance, in pixels, when generating approximate - polylines for instance masks. Typical values are 1-3 pixels + polylines for instance masks. Typical values are 1-3 pixels. By + default, IoUs are computed directly on the dense pixel masks max_preds (None): the maximum number of predicted objects to evaluate when computing mAP and PR curves error_level (1): the error level to use when manipulating instance diff --git a/fiftyone/utils/iou.py b/fiftyone/utils/iou.py index 65d32229c0..077e789309 100644 --- a/fiftyone/utils/iou.py +++ b/fiftyone/utils/iou.py @@ -16,6 +16,7 @@ import eta.core.numutils as etan import eta.core.utils as etau +import eta.core.image as etai import fiftyone.core.labels as fol import fiftyone.core.utils as fou @@ -73,7 +74,8 @@ def compute_ious( of the provided :class:`fiftyone.core.labels.Polyline` instances rather than using their actual geometries tolerance (None): a tolerance, in pixels, when generating approximate - polylines for instance masks. Typical values are 1-3 pixels + polylines for instance masks. 
Typical values are 1-3 pixels. By + default, IoUs are computed directly on the dense pixel masks sparse (False): whether to return a sparse dict of non-zero IoUs rather than a full matrix error_level (1): the error level to use when manipulating instance @@ -136,11 +138,6 @@ def compute_ious( ) if use_masks: - # @todo when tolerance is None, consider using dense masks rather than - # polygonal approximations? - if tolerance is None: - tolerance = 2 - return _compute_mask_ious( preds, gts, @@ -528,6 +525,65 @@ def compute_bbox_iou(gt, pred, gt_crowd=False): return min(etan.safe_divide(inter, union), 1) +def _dense_iou(gt, pred, gt_crowd=False): + """Computes the IoU between the given ground truth and predicted + detection masks. + + Args: + gt: a :class:`fiftyone.core.labels.Detection` + pred: a :class:`fiftyone.core.labels.Detection` + gt_crowd (False): whether the ground truth object is a crowd + + Returns: + the IoU, in ``[0, 1]`` + """ + gt_mask = gt.mask + gt_bb = gt.bounding_box # x,y,w,h of box + gt_mask_h, gt_mask_w = gt_mask.shape + + pred_mask = pred.mask + pred_bb = pred.bounding_box # x,y,w,h of box + pred_mask_h, pred_mask_w = pred_mask.shape + + gt_img_w = round(gt_mask_w / gt_bb[2]) + gt_img_h = round(gt_mask_h / gt_bb[3]) + + pred_img_w = round(pred_mask_w / pred_bb[2]) + pred_img_h = round(pred_mask_h / pred_bb[3]) + + gt_mask_full = np.zeros((gt_img_h, gt_img_w)) + pred_mask_full = np.zeros((pred_img_h, pred_img_w)) + + x1, y1, x2, y2 = _float_to_pixel(gt_bb, gt_img_w, gt_img_h) + gt_mask_full[y1:y2, x1:x2] = gt_mask + + x1, y1, x2, y2 = _float_to_pixel(pred_bb, pred_img_w, pred_img_h) + pred_mask_full[y1:y2, x1:x2] = pred_mask + + if gt_img_w != pred_img_w or gt_img_h != pred_img_h: + gt_size = gt_img_w * gt_img_h + pred_size = pred_img_w * pred_img_h + if gt_size > pred_size: + pred_mask_full = etai.resize( + pred_mask_full, + height=gt_img_h, + width=gt_img_w, + interpolation=0, # equivalent to cv2.INTER_NEAREST + ) + else: + gt_mask_full = etai.resize( + gt_mask_full, + height=pred_img_h, + width=pred_img_w, + interpolation=0, # equivalent to cv2.INTER_NEAREST + ) + + inter = np.logical_and(gt_mask_full, pred_mask_full).sum() + union = np.logical_or(gt_mask_full, pred_mask_full).sum() + + return min(etan.safe_divide(inter, union), 1) + + def _get_detection_box(det, dimension=None): if dimension is None: dimension = _get_bbox_dim(det) @@ -559,6 +615,10 @@ def _get_poly_box(x): return _get_detection_box(detection) +def _get_mask_box(x): + return _get_detection_box(x) + + def _compute_bbox_ious( preds, gts, @@ -600,6 +660,7 @@ def _compute_bbox_ious( for i, pred in enumerate(preds): box = _get_detection_box(pred, dimension=index_property.dimension) + # pylint: disable=no-value-for-parameter indices = rtree_index.intersection(box) for j in indices: # pylint: disable=not-an-iterable gt = gts[j] @@ -624,6 +685,61 @@ def _compute_bbox_ious( return ious +def _compute_dense_mask_ious( + preds, + gts, + error_level, + iscrowd=None, + classwise=False, + sparse=False, +): + is_symmetric = preds is gts + + if sparse: + ious = defaultdict(list) + else: + ious = np.zeros((len(preds), len(gts))) + + if iscrowd is not None: + gt_crowds = [iscrowd(gt) for gt in gts] + else: + gt_crowds = [False] * len(gts) + + index_property = rti.Property() + bbox_iou_fcn = compute_bbox_iou + index_property.dimension = 2 + + rtree_index = rti.Index(properties=index_property, interleaved=False) + for i, gt in enumerate(gts): + box = _get_mask_box(gt) + rtree_index.insert(i, box) + + for i, pred in 
enumerate(preds): + box = _get_mask_box(pred) + # pylint: disable=no-value-for-parameter + indices = rtree_index.intersection(box) + for j in indices: # pylint: disable=not-an-iterable + gt = gts[j] + gt_crowd = gt_crowds[j] + if classwise and pred.label != gt.label: + continue + + if is_symmetric and j > i: + continue + + iou = _dense_iou(gt, pred, gt_crowd=gt_crowd) + + if sparse: + ious[pred.id].append((gt.id, iou)) + if is_symmetric: + ious[gt.id].append((pred.id, iou)) + else: + ious[i, j] = iou + if is_symmetric: + ious[j, i] = iou + return ious + + def _compute_polygon_ious( preds, gts, @@ -675,6 +791,7 @@ def _compute_polygon_ious( zip(preds, pred_polys, pred_labels, pred_areas) ): box = _get_poly_box(pred) + # pylint: disable=no-value-for-parameter indices = rtree_index.intersection(box) for j in indices: # pylint: disable=not-an-iterable gt = gts[j] @@ -767,34 +884,43 @@ def _compute_mask_ious( ): is_symmetric = preds is gts - with contextlib.ExitStack() as context: - # We're ignoring errors, so suppress shapely logging that occurs when - # invalid geometries are encountered - if error_level > 1: - context.enter_context( - fou.LoggingLevel(logging.CRITICAL, logger="shapely") - ) - - pred_polys = _masks_to_polylines(preds, tolerance, error_level) - - if is_symmetric: - gt_polys = pred_polys - else: - gt_polys = _masks_to_polylines(gts, tolerance, error_level) - if iscrowd is not None: gt_crowds = [iscrowd(gt) for gt in gts] else: gt_crowds = [False] * len(gts) - return _compute_polygon_ious( - pred_polys, - gt_polys, - error_level, - classwise=classwise, - gt_crowds=gt_crowds, - sparse=sparse, - ) + if tolerance is None: + return _compute_dense_mask_ious( + preds, + gts, + error_level, + classwise=classwise, + sparse=sparse, + ) + else: + with contextlib.ExitStack() as context: + # We're ignoring errors, so suppress shapely logging that occurs when + # invalid geometries are encountered + if error_level > 1: + context.enter_context( + fou.LoggingLevel(logging.CRITICAL, logger="shapely") + ) + + pred_polys = _masks_to_polylines(preds, tolerance, error_level) + + if is_symmetric: + gt_polys = pred_polys + else: + gt_polys = _masks_to_polylines(gts, tolerance, error_level) + + return _compute_polygon_ious( + pred_polys, + gt_polys, + error_level, + classwise=classwise, + gt_crowds=gt_crowds, + sparse=sparse, + ) def _compute_segment_ious(preds, gts, sparse=False): @@ -910,6 +1036,14 @@ def _compute_object_keypoint_similarity(gtp, predp): return np.sum(np.exp(-(dists**2) / (2 * (scale**2)))) / n +def _float_to_pixel(gt_bb, img_w, img_h): + x1 = round(gt_bb[0] * img_w) + y1 = round(gt_bb[1] * img_h) + x2 = round(x1 + (gt_bb[2] * img_w)) + y2 = round(y1 + (gt_bb[3] * img_h)) + return x1, y1, x2, y2 + + def _polylines_to_detections(polylines): detections = [] for polyline in polylines:
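
For reference, a minimal sketch of how the new default behavior introduced by this patch could be exercised once the branch is installed. The labels, bounding boxes, and mask sizes below are made-up illustrative values; ``compute_ious()``, ``Detection``, ``use_masks``, and ``tolerance`` are the interfaces touched by this diff::

    import numpy as np

    import fiftyone.core.labels as fol
    import fiftyone.utils.iou as foui

    # Two overlapping instance masks; `bounding_box` is in relative
    # [x, y, w, h] coordinates and `mask` defines the object within its box
    gt = fol.Detection(
        label="cat",
        bounding_box=[0.1, 0.1, 0.4, 0.4],
        mask=np.ones((40, 40), dtype=bool),
    )
    pred = fol.Detection(
        label="cat",
        bounding_box=[0.2, 0.2, 0.4, 0.4],
        mask=np.ones((40, 40), dtype=bool),
    )

    # With this change, `tolerance=None` (the default) computes IoUs
    # directly on the dense pixel masks
    dense_ious = foui.compute_ious([pred], [gt], use_masks=True)

    # Passing a tolerance still uses polygonal approximations of the masks
    approx_ious = foui.compute_ious([pred], [gt], use_masks=True, tolerance=2)

    print(dense_ious)   # 1 x 1 matrix; roughly 0.39 for these example boxes
    print(approx_ious)

Because the ``tolerance`` option documented in the updated :class:`COCOEvaluationConfig` and :class:`OpenImagesEvaluationConfig` docstrings is forwarded from ``dataset.evaluate_detections(..., use_masks=True)``, instance segmentation evaluations run through those backends should pick up the dense-mask default as well; passing an explicit ``tolerance`` restores the polygonal-approximation path.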