Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NMS] - add segmentation models support #847

Merged
merged 24 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6a9c013
Added nms for segmentation algorithm, added unit test, added init, ad…
AdonaiVera Feb 3, 2024
aaf3621
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 3, 2024
ebe2785
Removed issue of length sentence
AdonaiVera Feb 3, 2024
1e87a99
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 3, 2024
9ee42e5
mask_iou_batch speed test
SkalskiP Feb 6, 2024
cf63c83
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 6, 2024
fd70f3c
Applied inclusion-exclusion principle to mask_iou_batch and resize fu…
AdonaiVera Feb 7, 2024
879cf6d
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 7, 2024
35c9ebd
refactored `mask_non_max_suppression`
SkalskiP Feb 7, 2024
239f733
refactored `mask_non_max_suppression`
SkalskiP Feb 7, 2024
c98dbbf
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 7, 2024
7e44838
refactored `mask_non_max_suppression`
SkalskiP Feb 7, 2024
c46a2e9
Merge remote-tracking branch 'origin/segmentation_nms' into segmentat…
SkalskiP Feb 7, 2024
8f6887d
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 7, 2024
f2e65cf
refactored `mask_non_max_suppression`
SkalskiP Feb 7, 2024
619fa83
Merge remote-tracking branch 'origin/segmentation_nms' into segmentat…
SkalskiP Feb 7, 2024
b681274
Merge remote-tracking branch 'origin/segmentation_nms' into segmentat…
SkalskiP Feb 7, 2024
71836ed
Merge remote-tracking branch 'origin/segmentation_nms' into segmentat…
SkalskiP Feb 7, 2024
98e7310
test vectorized mask_non_max_suppression
SkalskiP Feb 7, 2024
bbe0bfc
final tests
SkalskiP Feb 7, 2024
cbbfb45
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Feb 7, 2024
79a13d0
final tests
SkalskiP Feb 7, 2024
36ec1b6
Merge remote-tracking branch 'origin/segmentation_nms' into segmentat…
SkalskiP Feb 7, 2024
824a914
tests fixed
SkalskiP Feb 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions docs/detection/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,22 @@ comments: true
:::supervision.detection.utils.box_iou_batch
AdonaiVera marked this conversation as resolved.
Show resolved Hide resolved

<div class="md-typeset">
<h2>non_max_suppression</h2>
<h2>mask_iou_batch</h2>
</div>

:::supervision.detection.utils.non_max_suppression
:::supervision.detection.utils.mask_iou_batch

<div class="md-typeset">
<h2>box_non_max_suppression</h2>
</div>

:::supervision.detection.utils.box_non_max_suppression

<div class="md-typeset">
<h2>mask_non_max_suppression</h2>
</div>

:::supervision.detection.utils.mask_non_max_suppression

<div class="md-typeset">
<h2>polygon_to_mask</h2>
Expand Down
4 changes: 3 additions & 1 deletion supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@
from supervision.detection.tools.smoother import DetectionsSmoother
from supervision.detection.utils import (
box_iou_batch,
AdonaiVera marked this conversation as resolved.
Show resolved Hide resolved
box_non_max_suppression,
calculate_masks_centroids,
filter_polygons_by_area,
mask_iou_batch,
mask_non_max_suppression,
mask_to_polygons,
mask_to_xyxy,
move_boxes,
non_max_suppression,
polygon_to_mask,
polygon_to_xyxy,
scale_boxes,
Expand Down
36 changes: 23 additions & 13 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.utils import (
box_non_max_suppression,
calculate_masks_centroids,
extract_ultralytics_masks,
get_data_item,
is_data_equal,
mask_non_max_suppression,
merge_data,
non_max_suppression,
process_roboflow_result,
validate_detections_fields,
xywh_to_xyxy,
Expand Down Expand Up @@ -1001,7 +1002,8 @@ def with_nms(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
"""
Perform non-maximum suppression on the current set of object detections.
Performs non-max suppression on detection set. If the detections result
from a segmentation model, the IoU mask is applied. Otherwise, box IoU is used.

Args:
threshold (float, optional): The intersection-over-union threshold
Expand All @@ -1028,18 +1030,26 @@ def with_nms(

if class_agnostic:
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
indices = non_max_suppression(
predictions=predictions, iou_threshold=threshold
else:
assert self.class_id is not None, (
"Detections class_id must be given for NMS to be executed. If you"
" intended to perform class agnostic NMS set class_agnostic=True."
)
predictions = np.hstack(
(
self.xyxy,
self.confidence.reshape(-1, 1),
self.class_id.reshape(-1, 1),
)
)
return self[indices]

assert self.class_id is not None, (
"Detections class_id must be given for NMS to be executed. If you intended"
" to perform class agnostic NMS set class_agnostic=True."
)
if self.mask is not None:
indices = mask_non_max_suppression(
predictions=predictions, masks=self.mask, iou_threshold=threshold
)
else:
indices = box_non_max_suppression(
predictions=predictions, iou_threshold=threshold
)

predictions = np.hstack(
(self.xyxy, self.confidence.reshape(-1, 1), self.class_id.reshape(-1, 1))
)
indices = non_max_suppression(predictions=predictions, iou_threshold=threshold)
return self[indices]
114 changes: 113 additions & 1 deletion supervision/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,119 @@ def box_area(box):
return area_inter / (area_true[:, None] + area_detection - area_inter)


def non_max_suppression(
def mask_iou_batch(masks_true: np.ndarray, masks_detection: np.ndarray) -> np.ndarray:
    """
    Compute Intersection over Union (IoU) of two sets of masks -
        `masks_true` and `masks_detection`.

    Masks are treated as binary: every non-zero pixel counts as foreground.
    This keeps the result correct for masks stored as `uint8` with values
    such as `0`/`255`, not only for boolean arrays.

    Args:
        masks_true (np.ndarray): 3D `np.ndarray` of shape `(N, H, W)`
            representing ground-truth masks.
        masks_detection (np.ndarray): 3D `np.ndarray` of shape `(M, H, W)`
            representing detection masks.

    Returns:
        np.ndarray: Float array of shape `(N, M)` with the pairwise IoU of masks
            from `masks_true` and `masks_detection`. Pairs whose union is empty
            get an IoU of `0`.
    """
    # Normalise to boolean so intersection and areas are both pixel counts;
    # summing raw values would inflate the areas of e.g. 0/255 masks.
    masks_true = np.asarray(masks_true).astype(bool, copy=False)
    masks_detection = np.asarray(masks_detection).astype(bool, copy=False)

    # Broadcast (N, 1, H, W) against (M, H, W) -> (N, M, H, W), then count pixels.
    intersection_area = np.logical_and(masks_true[:, None], masks_detection).sum(
        axis=(2, 3)
    )
    masks_true_area = masks_true.sum(axis=(1, 2))
    masks_detection_area = masks_detection.sum(axis=(1, 2))

    # Inclusion-exclusion: |A or B| = |A| + |B| - |A and B|.
    union_area = masks_true_area[:, None] + masks_detection_area - intersection_area

    # `where` skips empty unions; those entries keep the 0.0 from `out`.
    return np.divide(
        intersection_area,
        union_area,
        out=np.zeros(intersection_area.shape, dtype=float),
        where=union_area != 0,
    )


def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
    """
    Resize all masks in the array so that their larger side equals
    `max_dimension`, maintaining aspect ratio.

    Resizing is done by nearest-index sampling: evenly spaced row and column
    indices are picked from the source masks, so no interpolation is performed
    and the dtype of `masks` is preserved.

    Args:
        masks (np.ndarray): 3D array of binary masks with shape (N, H, W).
        max_dimension (int): The maximum dimension for the resized masks.

    Returns:
        np.ndarray: Array of resized masks with shape (N, new_H, new_W). Each
            output side is at least one pixel. Masks with a zero-sized spatial
            dimension are returned unchanged.
    """
    height, width = masks.shape[1], masks.shape[2]
    if height == 0 or width == 0:
        # Nothing to sample from; also avoids dividing by zero below.
        return masks

    scale = min(max_dimension / height, max_dimension / width)

    # Clamp to 1: for extreme aspect ratios the shorter side would otherwise
    # truncate to 0 and yield empty masks (every IoU would then be 0).
    new_height = max(int(scale * height), 1)
    new_width = max(int(scale * width), 1)

    x = np.linspace(0, width - 1, new_width).astype(int)
    y = np.linspace(0, height - 1, new_height).astype(int)

    # Broadcasting a column of row indices against a row of column indices
    # selects the full (new_height, new_width) sampling grid for every mask.
    return masks[:, y[:, None], x[None, :]]


def mask_non_max_suppression(
    predictions: np.ndarray,
    masks: np.ndarray,
    iou_threshold: float = 0.5,
    mask_dimension: int = 640,
) -> np.ndarray:
    """
    Run Non-Maximum Suppression (NMS) on segmentation predictions, using the
    overlap between masks rather than between boxes.

    Predictions are visited from highest to lowest score. A kept prediction
    suppresses every lower-scored prediction of the same category whose mask
    IoU with it exceeds `iou_threshold`.

    Args:
        predictions (np.ndarray): A 2D array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
            `(N, 6)`, where N is the number of predictions. When the class column
            is absent, all predictions are treated as one category.
        masks (np.ndarray): A 3D array of binary masks corresponding to the
            predictions. Shape: `(N, H, W)`, where N is the number of predictions,
            and H, W are the dimensions of each mask.
        iou_threshold (float, optional): The intersection-over-union threshold
            to use for non-maximum suppression.
        mask_dimension (int, optional): The dimension to which the masks should be
            resized before computing IOU values. Defaults to 640.

    Returns:
        np.ndarray: A boolean array, in the original order of `predictions`,
            indicating which predictions to keep after non-maximum suppression.

    Raises:
        AssertionError: If `iou_threshold` is not within the closed
            range from `0` to `1`.
    """
    assert 0 <= iou_threshold <= 1, (
        "Value of `iou_threshold` must be in the closed range from 0 to 1, "
        f"{iou_threshold} given."
    )
    rows, columns = predictions.shape

    # Without a class column, append a constant one so every prediction
    # falls into the same category.
    if columns == 5:
        predictions = np.c_[predictions, np.zeros(rows)]

    # Order by descending score so each entry can only suppress later ones.
    order = predictions[:, 4].argsort()[::-1]
    ranked_predictions = predictions[order]
    ranked_masks = resize_masks(masks[order], mask_dimension)
    categories = ranked_predictions[:, 5]
    ious = mask_iou_batch(ranked_masks, ranked_masks)

    keep = np.ones(rows, dtype=bool)
    for i in range(rows):
        if not keep[i]:
            continue
        same_object = (ious[i] > iou_threshold) & (categories == categories[i])
        # Only lower-ranked entries (after position i) may be dropped.
        keep[i + 1 :] &= ~same_object[i + 1 :]

    # Undo the sort so the flags line up with the caller's original ordering.
    return keep[order.argsort()]


def box_non_max_suppression(
predictions: np.ndarray, iou_threshold: float = 0.5
) -> np.ndarray:
"""
Expand Down
Loading