[Feature] Add DOTAMeanAP metric #65

Merged
merged 46 commits into from
Jan 28, 2023
Changes from 2 commits
101e4e5
add dota_map
YanxingLiu Dec 11, 2022
e713bf0
add dota_map docstrings
YanxingLiu Dec 11, 2022
1c81d19
modify some docstrings
YanxingLiu Dec 14, 2022
f952b83
modify some docstrings
YanxingLiu Dec 14, 2022
3159446
add rotated iou calculation with mmcv backend
YanxingLiu Dec 14, 2022
b5e382e
modify some function names
YanxingLiu Dec 15, 2022
c0c5714
Update mmeval/metrics/dota_map.py
YanxingLiu Dec 20, 2022
cbbf225
implement filter_by_bboxes_area as a class method
YanxingLiu Dec 20, 2022
415f801
Merge branch 'AddDOTAMetric' of github.com:YanxingLiu/mmeval into Add…
YanxingLiu Dec 20, 2022
137cebd
implement filter_by_bboxes_area as a class method
YanxingLiu Dec 20, 2022
e803ef9
Update mmeval/metrics/dota_map.py
YanxingLiu Dec 20, 2022
2c1c175
modify function name from filter_by_bboxes_area to _filter_by_bboxes_…
YanxingLiu Dec 20, 2022
6cbc2e3
fix a bug thta occurs when mmcv is installed
YanxingLiu Dec 24, 2022
a78c9f8
add qbox support
YanxingLiu Dec 27, 2022
ea1939a
modify docstrings for quadrilateral boxes support
YanxingLiu Dec 28, 2022
e828618
Apply suggestions from code review
ice-tong Jan 4, 2023
f8a663d
fix lint
ice-tong Jan 4, 2023
ce58a23
Update mmeval/metrics/dota_map.py
YanxingLiu Jan 10, 2023
a7cd54f
Update mmeval/metrics/dota_map.py
YanxingLiu Jan 10, 2023
c267631
Update mmeval/metrics/voc_map.py
YanxingLiu Jan 10, 2023
c68c76f
modify DOTAMeanAP docstrings
YanxingLiu Jan 10, 2023
1e35757
Merge branch 'AddDOTAMetric' of github.com:YanxingLiu/mmeval into Add…
YanxingLiu Jan 10, 2023
0db424a
modify VOCMeanAP docstrings
YanxingLiu Jan 10, 2023
7fbd3ef
Merge branch 'main' into AddDOTAMetric
YanxingLiu Jan 11, 2023
3582b99
add DOTAMeanAP to metrics.rst
YanxingLiu Jan 11, 2023
3073d20
merge from upstream
YanxingLiu Jan 11, 2023
e695933
Update mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 16, 2023
961196f
add docstring in mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 16, 2023
56901cf
Merge branch 'AddDOTAMetric' of github.com:YanxingLiu/mmeval into Add…
YanxingLiu Jan 16, 2023
239649d
add some test cases and some assertion
YanxingLiu Jan 16, 2023
aa8be5a
Update mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 16, 2023
4f6dc0f
Update mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 16, 2023
8b231ea
Update mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 19, 2023
7b045b7
Update mmeval/metrics/dota_map.py
YanxingLiu Jan 19, 2023
8efb43b
add opencv-python in requirements/runtime.txt
YanxingLiu Jan 19, 2023
39b6aa6
fix: use try_import to import cv2
ice-tong Jan 19, 2023
5cac2ae
Merge branch 'main' into AddDOTAMetric
ice-tong Jan 19, 2023
136bf32
Update mmeval/metrics/dota_map.py
YanxingLiu Jan 25, 2023
f629d3f
Update mmeval/metrics/dota_map.py
YanxingLiu Jan 25, 2023
ee381b3
fix a bug caused by static function
YanxingLiu Jan 25, 2023
db2e60b
modify docstring of filter_by_bboxes_area_rotated
YanxingLiu Jan 25, 2023
11dfb09
Update mmeval/metrics/voc_map.py
YanxingLiu Jan 25, 2023
8a4c255
Update mmeval/metrics/dota_map.py
YanxingLiu Jan 25, 2023
1f2973a
Update mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 25, 2023
eff2898
Update mmeval/metrics/utils/bbox_overlaps_rotated.py
YanxingLiu Jan 25, 2023
a3d46b1
add test_metric_accurate function in test_dota_map.py
YanxingLiu Jan 28, 2023
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
@@ -3,6 +3,7 @@
from .accuracy import Accuracy
from .ava_map import AVAMeanAP
from .coco_detection import COCODetectionMetric
+from .dota_map import DOTAMeanAP
from .end_point_error import EndPointError
from .f_metric import F1Metric
from .hmean_iou import HmeanIoU
@@ -24,5 +25,5 @@
'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric',
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
-    'AveragePrecision', 'AVAMeanAP'
+    'AveragePrecision', 'AVAMeanAP', 'DOTAMeanAP'
]
274 changes: 274 additions & 0 deletions mmeval/metrics/dota_map.py
@@ -0,0 +1,274 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from multiprocessing.pool import Pool
from typing import Dict, List, Optional, Sequence, Tuple

from mmeval.metrics.voc_map import VOCMeanAP
from .utils.bbox_iou_rotated import (bbox_iou_rotated,
                                     calculate_bboxes_area_rotated)


def filter_by_bboxes_area_rotated(bboxes: np.ndarray,
                                  min_area: Optional[float],
                                  max_area: Optional[float]):
    """Filter the rotated bboxes with an area range.

    Args:
        bboxes (numpy.ndarray): The bboxes with shape (n, 5) in 'xywha' format.
        min_area (Optional[float]): The minimum area. If None, does not filter
            the minimum area.
        max_area (Optional[float]): The maximum area. If None, does not filter
            the maximum area.

    Returns:
        numpy.ndarray: A boolean mask over ``bboxes``. True marks the bboxes
        whose area lies in [min_area, max_area).
    """
    bboxes_area = calculate_bboxes_area_rotated(bboxes)
    area_mask = np.ones_like(bboxes_area, dtype=bool)
    if min_area is not None:
        area_mask &= (bboxes_area >= min_area)
    if max_area is not None:
        area_mask &= (bboxes_area < max_area)
    return area_mask


class DOTAMeanAP(VOCMeanAP):
    """DOTA evaluation metric.

    This metric computes the DOTA mAP (mean Average Precision) with the given
    IoU thresholds and scale ranges.

    Args:
        iou_thrs (float | List[float]): IoU thresholds. Defaults to 0.5.
        scale_ranges (List[tuple], optional): Scale ranges for evaluating
            mAP. If not specified, all bounding boxes would be included in
            evaluation. Defaults to None.
        num_classes (int, optional): The number of classes. If None, it will be
            obtained from the 'CLASSES' field in ``self.dataset_meta``.
            Defaults to None.
        eval_mode (str): 'area' or '11points', 'area' means calculating the
            area under precision-recall curve, '11points' means calculating
            the average precision of recalls at [0, 0.1, ..., 1].
            The PASCAL VOC2007 defaults to use '11points', while PASCAL
            VOC2012 defaults to use 'area'.
            Defaults to '11points'.
        nproc (int): Processes used for computing TP and FP. If nproc
            is less than or equal to 1, multiprocessing will not be used.
            Defaults to 4.
        drop_class_ap (bool): Whether to drop the class without ground truth
            when calculating the average precision for each class.
        classwise (bool): Whether to return the computed results of each
            class. Defaults to False.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:
        >>> import numpy as np
        >>> from mmeval import DOTAMeanAP
        >>> num_classes = 15
        >>> dota_metric = DOTAMeanAP(num_classes=15)
        >>>
        >>> def _gen_bboxes(num_bboxes, img_w=256, img_h=256):
        ...     # randomly generate bounding boxes in 'xywha' format.
        ...     x = np.random.rand(num_bboxes, ) * img_w
        ...     y = np.random.rand(num_bboxes, ) * img_h
        ...     w = np.random.rand(num_bboxes, ) * (img_w - x)
        ...     h = np.random.rand(num_bboxes, ) * (img_h - y)
        ...     a = np.random.rand(num_bboxes, ) * np.pi / 2
        ...     return np.stack([x, y, w, h, a], axis=1)
        >>> prediction = {
        ...     'bboxes': _gen_bboxes(10),
        ...     'scores': np.random.rand(10, ),
        ...     'labels': np.random.randint(0, num_classes, size=(10, ))
        ... }
        >>> groundtruth = {
        ...     'bboxes': _gen_bboxes(10),
        ...     'labels': np.random.randint(0, num_classes, size=(10, )),
        ...     'bboxes_ignore': _gen_bboxes(5),
        ...     'labels_ignore': np.random.randint(0, num_classes, size=(5, ))
        ... }
        >>> dota_metric(predictions=[prediction, ], groundtruths=[groundtruth, ])  # doctest: +ELLIPSIS  # noqa: E501
        {'mAP@0.5': ..., 'mAP': ...}
    """

    def __init__(self, eval_mode: str = '11points', **kwargs) -> None:
        super().__init__(eval_mode=eval_mode, **kwargs)

    def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add the intermediate results to ``self._results``.

        Args:
            predictions (Sequence[Dict]): A sequence of dict. Each dict
                represents a detection result for an image, with the
                following keys:

                - bboxes (numpy.ndarray): Shape (N, 5), the predicted
                  bounding boxes of this image, in 'xywha' format.
                - scores (numpy.ndarray): Shape (N, ), the predicted scores
                  of bounding boxes.
                - labels (numpy.ndarray): Shape (N, ), the predicted labels
                  of bounding boxes.

            groundtruths (Sequence[Dict]): A sequence of dict. Each dict
                represents the ground truth for an image, with the following
                keys:

                - bboxes (numpy.ndarray): Shape (M, 5), the ground truth
                  bounding boxes of this image, in 'xywha' format.
                - labels (numpy.ndarray): Shape (M, ), the ground truth
                  labels of bounding boxes.
                - bboxes_ignore (numpy.ndarray): Shape (K, 5), the ground
                  truth ignored bounding boxes of this image,
                  in 'xywha' format.
                - labels_ignore (numpy.ndarray): Shape (K, ), the ground
                  truth ignored labels of bounding boxes.
        """
        for prediction, groundtruth in zip(predictions, groundtruths):
            assert isinstance(prediction, dict), 'The prediction should be ' \
                f'a sequence of dict, but got a sequence of {type(prediction)}.'  # noqa: E501
            assert isinstance(groundtruth, dict), 'The label should be ' \
                f'a sequence of dict, but got a sequence of {type(groundtruth)}.'  # noqa: E501
            self._results.append((prediction, groundtruth))

    @staticmethod
    def _calculate_image_tpfp(  # type: ignore
        pred_bboxes: np.ndarray,
        gt_bboxes: np.ndarray,
        ignore_gt_bboxes: np.ndarray,
        iou_thrs: List[float],
        area_ranges: List[Tuple[Optional[float], Optional[float]]],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Calculate the true positive and false positive on an image.

        Args:
            pred_bboxes (numpy.ndarray): Predicted bboxes of this image, with
                shape (N, 6). The predicted score of each bbox is
                concatenated after the bbox.
            gt_bboxes (numpy.ndarray): Ground truth bboxes of this image, with
                shape (M, 5).
            ignore_gt_bboxes (numpy.ndarray): Ground truth ignored bboxes of
                this image, with shape (K, 5).
            iou_thrs (List[float]): The IoU thresholds.
            area_ranges (List[Tuple]): The area ranges.

        Returns:
            tuple (tp, fp):

            - tp (numpy.ndarray): Shape (num_ious, num_scales, N),
              the true positive flag of each predicted bbox on this image.
            - fp (numpy.ndarray): Shape (num_ious, num_scales, N),
              the false positive flag of each predicted bbox on this image.

        Note:
            This method should be a staticmethod to avoid resource competition
            during multiple processes.
        """
        # Step 1. Concatenate `gt_bboxes` and `ignore_gt_bboxes`, then set
        # the `ignore_gt_flags`.
        all_gt_bboxes = np.concatenate((gt_bboxes, ignore_gt_bboxes))
        ignore_gt_flags = np.concatenate((np.zeros(
            (gt_bboxes.shape[0], 1),
            dtype=bool), np.ones((ignore_gt_bboxes.shape[0], 1), dtype=bool)))

        # Step 2. Initialize the `tp` and `fp` arrays.
        num_preds = pred_bboxes.shape[0]
        tp = np.zeros((len(iou_thrs), len(area_ranges), num_preds))
        fp = np.zeros((len(iou_thrs), len(area_ranges), num_preds))

        # Step 3. If there are no gt bboxes in this image, then all pred bboxes
        # within area range are false positives.
        if all_gt_bboxes.shape[0] == 0:
            for idx, (min_area, max_area) in enumerate(area_ranges):
                area_mask = filter_by_bboxes_area_rotated(
                    pred_bboxes[:, :5], min_area, max_area)
                fp[:, idx, area_mask] = 1
            return tp, fp

        # Step 4. Calculate the IoUs between the predicted bboxes and the
        # ground truth bboxes.
        ious = bbox_iou_rotated(pred_bboxes[:, :5], all_gt_bboxes)
        # For each pred bbox, the max iou with all gts.
        ious_max = ious.max(axis=1)
        # For each pred bbox, which gt overlaps most with it.
        ious_argmax = ious.argmax(axis=1)
        # Sort all pred bboxes in descending order by scores.
        sorted_indices = np.argsort(-pred_bboxes[:, -1])

        # Step 5. Count the `tp` and `fp` of each iou threshold and area range.
        for iou_thr_idx, iou_thr in enumerate(iou_thrs):
            for area_idx, (min_area, max_area) in enumerate(area_ranges):
                # Flags marking the gt bboxes that have been matched.
                gt_covered_flags = np.zeros(all_gt_bboxes.shape[0], dtype=bool)
                # Flags marking the gt bboxes that are out of area range.
                gt_area_mask = filter_by_bboxes_area_rotated(
                    all_gt_bboxes, min_area, max_area)
                ignore_gt_area_flags = ~gt_area_mask

                # Count the prediction bboxes in order of decreasing score.
                for pred_bbox_idx in sorted_indices:
                    if ious_max[pred_bbox_idx] >= iou_thr:
                        matched_gt_idx = ious_argmax[pred_bbox_idx]
                        # Ignore the pred bbox that matches an ignored gt bbox.
                        if ignore_gt_flags[matched_gt_idx]:
                            continue
                        # Ignore the pred bbox whose matched gt bbox is out of
                        # area range.
                        if ignore_gt_area_flags[matched_gt_idx]:
                            continue
                        if not gt_covered_flags[matched_gt_idx]:
                            tp[iou_thr_idx, area_idx, pred_bbox_idx] = 1
                            gt_covered_flags[matched_gt_idx] = True
                        else:
                            # This gt bbox has already been matched, so the
                            # pred bbox is counted as fp.
                            fp[iou_thr_idx, area_idx, pred_bbox_idx] = 1
                    else:
                        area_mask = filter_by_bboxes_area_rotated(
                            pred_bboxes[pred_bbox_idx, :5], min_area, max_area)
                        if area_mask:
                            fp[iou_thr_idx, area_idx, pred_bbox_idx] = 1

        return tp, fp

    def calculate_class_tpfp(self, predictions: List[dict],
                             groundtruths: List[dict], class_index: int,
                             pool: Optional[Pool]) -> Tuple:
        """Calculate the tp and fp of the given class index.

        Args:
            predictions (List[dict]): A list of dict. Each dict is the
                detection result of an image.
            groundtruths (List[dict]): A list of dict. Each dict is the
                ground truth of an image.
            class_index (int): The class index.
            pool (Optional[Pool]): An instance of
                :class:`multiprocessing.Pool`. If None, do not use
                multiprocessing.

        Returns:
            tuple (tp, fp, num_gts):

            - tp (numpy.ndarray): Shape (num_ious, num_scales, num_pred),
              the true positive flag of each predicted bbox for this class.
            - fp (numpy.ndarray): Shape (num_ious, num_scales, num_pred),
              the false positive flag of each predicted bbox for this class.
            - num_gts (numpy.ndarray): Shape (num_ious, num_scales), the
              number of ground truths.
        """
        class_preds = self.get_class_predictions(predictions, class_index)
        class_gts, class_ignore_gts = self.get_class_gts(
            groundtruths, class_index)
        if pool is not None:
            num_images = len(class_preds)
            tpfp_list = pool.starmap(
                self._calculate_image_tpfp,
                zip(
                    class_preds,
                    class_gts,
                    class_ignore_gts,
                    [self.iou_thrs] * num_images,
                    [self._area_ranges] * num_images,
                ))
        else:
            tpfp_list = []
            for img_idx in range(len(class_preds)):
                tpfp = self._calculate_image_tpfp(class_preds[img_idx],
                                                  class_gts[img_idx],
                                                  class_ignore_gts[img_idx],
                                                  self.iou_thrs,
                                                  self._area_ranges)
                tpfp_list.append(tpfp)

        image_tp_list, image_fp_list = tuple(zip(*tpfp_list))
        sorted_indices = np.argsort(-np.vstack(class_preds)[:, -1])
        tp = np.concatenate(image_tp_list, axis=2)[..., sorted_indices]
        fp = np.concatenate(image_fp_list, axis=2)[..., sorted_indices]
        num_gts = np.zeros((self.num_iou, self.num_scale), dtype=int)
        for idx, (min_area, max_area) in enumerate(self._area_ranges):
            area_mask = filter_by_bboxes_area_rotated(
                np.vstack(class_gts), min_area, max_area)
            num_gts[:, idx] = np.sum(area_mask)

        return tp, fp, num_gts
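
Reviewer note: the matching in `_calculate_image_tpfp` and the '11points' mode named in the class docstring may be easier to follow on a toy case. The sketch below is not the code in this PR. It assumes a single image, a single class, one IoU threshold, no ignored boxes and no area ranges, and takes a precomputed IoU matrix as input. The names `greedy_tpfp` and `eleven_point_ap` are hypothetical and exist only for this illustration.

```python
import numpy as np


def greedy_tpfp(ious, scores, iou_thr=0.5):
    """Simplified TP/FP assignment (hypothetical helper, not the PR code).

    ious: (N, M) IoU matrix between N predictions and M ground truths.
    scores: (N,) confidence of each prediction.
    """
    num_preds, num_gts = ious.shape
    tp = np.zeros(num_preds)
    fp = np.zeros(num_preds)
    if num_gts == 0:
        fp[:] = 1  # nothing to match: every prediction is a false positive
        return tp, fp
    best_iou = ious.max(axis=1)    # best overlap of each prediction
    best_gt = ious.argmax(axis=1)  # which gt gives that overlap
    covered = np.zeros(num_gts, dtype=bool)
    # Visit predictions from the highest score to the lowest.
    for idx in np.argsort(-scores):
        if best_iou[idx] >= iou_thr and not covered[best_gt[idx]]:
            tp[idx] = 1
            covered[best_gt[idx]] = True
        else:
            # Either the overlap is too low or the gt is already taken.
            fp[idx] = 1
    return tp, fp


def eleven_point_ap(tp, fp, scores, num_gts):
    """11-point interpolated AP from per-prediction TP/FP flags."""
    order = np.argsort(-scores)
    tp_cum = np.cumsum(tp[order])
    fp_cum = np.cumsum(fp[order])
    recalls = tp_cum / max(num_gts, 1)
    precisions = tp_cum / np.maximum(tp_cum + fp_cum, np.finfo(float).eps)
    ap = 0.0
    for thr in np.linspace(0.0, 1.0, 11):
        # Best precision among operating points with recall >= thr.
        reachable = precisions[recalls >= thr]
        ap += reachable.max() if reachable.size else 0.0
    return ap / 11


# Toy data: three predictions, two ground truths.
ious = np.array([[0.8, 0.1], [0.7, 0.2], [0.1, 0.6]])
scores = np.array([0.9, 0.8, 0.7])
tp, fp = greedy_tpfp(ious, scores)
ap = eleven_point_ap(tp, fp, scores, num_gts=ious.shape[1])
print(tp, fp, round(ap, 4))
```

In the toy data the second prediction overlaps the same ground truth as the first one, so it is counted as a false positive, which is the duplicate-detection case handled by `gt_covered_flags` in the PR.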
93 changes: 93 additions & 0 deletions mmeval/metrics/utils/bbox_iou_rotated.py
@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np


def le90_to_oc(bboxes: np.ndarray):
    """Convert bboxes from the le90 angle version to the OpenCV version.

    Args:
        bboxes (np.ndarray): The shape of bboxes should be [N,5],
            the format is [x,y,w,h,angle], with the angle in radians.
    Returns:
        np.ndarray: A numpy.ndarray with the same shape as the input,
        with the angle in degrees.
    """
    assert bboxes.shape[1] == 5, 'The boxes shape should be [N,5]'

    # a mask of the input angles that are <= 0, i.e. outside (0, pi/2]
    mask = bboxes[:, 4] <= 0.0
    # shift the masked angles by pi/2
    ret_bboxes = bboxes.copy()
    ret_bboxes[:, 4] += np.pi / 2 * np.ones(bboxes.shape[0]) * mask
    # swap w and h of the shifted boxes
    temp = ret_bboxes[mask]
    temp[:, [2, 3]] = temp[:, [3, 2]]
    ret_bboxes[mask] = temp
    # radians to degrees
    ret_bboxes[:, 4] = ret_bboxes[:, 4] * 180.0 / np.pi
    return ret_bboxes


def calculate_bboxes_area_rotated(bboxes: np.ndarray) -> np.ndarray:
    """Calculate area of rotated bounding boxes.

    Args:
        bboxes (np.ndarray): The bboxes with shape (n, 5) or (5, )
            in 'xywha' format.
    Returns:
        np.ndarray: The area of bboxes.
    """
    bboxes_w = bboxes[..., 2]
    bboxes_h = bboxes[..., 3]
    areas = bboxes_w * bboxes_h
    return areas


def bbox_iou_rotated(bboxes1: np.ndarray,
                     bboxes2: np.ndarray,
                     clockwise: bool = True) -> np.ndarray:
    """Calculate the overlap between each rotated bbox of bboxes1 and bboxes2.

    Args:
        bboxes1 (np.ndarray): The bboxes with shape (n, 5) in 'xywha' format.
        bboxes2 (np.ndarray): The bboxes with shape (k, 5) in 'xywha' format.
        clockwise (bool, optional): flag indicating whether the positive
            angular orientation is clockwise. Defaults to True.
    Returns:
        np.ndarray: IoUs with shape (n, k).
    """
    bboxes1 = bboxes1.astype(np.float32)
    bboxes2 = bboxes2.astype(np.float32)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    ious = np.zeros((rows, cols), dtype=np.float32)

    if rows * cols == 0:
        return ious

    if not clockwise:
        flip_mat = np.ones(bboxes1.shape[-1])
        flip_mat[-1] = -1
        bboxes1 = bboxes1 * flip_mat
        bboxes2 = bboxes2 * flip_mat

    # convert angle version
    bboxes1 = le90_to_oc(bboxes1)
    bboxes2 = le90_to_oc(bboxes2)

    area1 = bboxes1[:, 2] * bboxes1[:, 3]
    area2 = bboxes2[:, 2] * bboxes2[:, 3]
    for i, box1 in enumerate(bboxes1):
        r1 = ((box1[0], box1[1]), (box1[2], box1[3]), box1[4])
        for j, box2 in enumerate(bboxes2):
            r2 = ((box2[0], box2[1]), (box2[2], box2[3]), box2[4])
            int_pts = cv2.rotatedRectangleIntersection(r1, r2)[1]
            if int_pts is not None:
                order_pts = cv2.convexHull(int_pts, returnPoints=True)
                int_area = cv2.contourArea(order_pts)
                inter = int_area * 1.0 / (
                    area1[i] + area2[j] - int_area + 1e-5)
                ious[i][j] = inter
            else:
                ious[i][j] = 0.0
    return ious
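
Reviewer note: the angle conversion in `le90_to_oc` relies on a rectangle being unchanged when its angle is shifted by 90 degrees and its width and height are swapped. The sketch below checks that property with NumPy only. It is an illustration under assumptions, not part of this PR: `le90_to_oc_sketch`, `corners` and `same_rectangle` are hypothetical names, and the corner formula assumes `w` runs along the box angle and `h` perpendicular to it.

```python
import numpy as np


def le90_to_oc_sketch(bboxes):
    """Hypothetical restatement of the conversion, for illustration only.

    Input angles are in radians, output angles are in degrees.
    """
    out = bboxes.astype(float).copy()
    mask = out[:, 4] <= 0.0
    out[mask, 4] += np.pi / 2       # shift non-positive angles by 90 degrees
    out[mask, 2] = bboxes[mask, 3]  # swap width and height for those boxes
    out[mask, 3] = bboxes[mask, 2]
    out[:, 4] = np.degrees(out[:, 4])
    return out


def corners(box, degrees=False):
    """Four corner points of an 'xywha' box (assumed convention, see note)."""
    x, y, w, h, a = box
    if degrees:
        a = np.radians(a)
    u = np.array([np.cos(a), np.sin(a)]) * w / 2   # half extent along w
    v = np.array([-np.sin(a), np.cos(a)]) * h / 2  # half extent along h
    c = np.array([x, y])
    return np.array([c + u + v, c + u - v, c - u - v, c - u + v])


def same_rectangle(p, q, tol=1e-6):
    """True if every corner in p coincides with some corner in q."""
    dists = np.linalg.norm(p[:, None, :] - q[None, :, :], axis=-1)
    return bool((dists.min(axis=1) < tol).all())


box = np.array([10.0, 20.0, 4.0, 2.0, -np.pi / 4])
converted = le90_to_oc_sketch(box[None, :])[0]
print(converted)
print(same_rectangle(corners(box), corners(converted, degrees=True)))
```

The check compares corner sets rather than corner order, because the shifted representation lists the same four points starting from a different corner.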
4 changes: 4 additions & 0 deletions setup.cfg
@@ -12,9 +12,13 @@ BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true

+# ignore-words-list needs to be lowercase format. For example, if we want to
+# ignore word "BA", then we need to append "ba" to ignore-words-list rather
+# than "BA"
[codespell]
skip = *.ipynb
quiet-level = 3
+ignore-words-list = dota

[mypy]
allow_redefinition = True