Skip to content

Commit

Permalink
add shape checking for MS-SSIM input and revise docstring as comment
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Mar 8, 2023
1 parent 40072ed commit fb3387e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 28 deletions.
49 changes: 28 additions & 21 deletions mmeval/metrics/ms_ssim.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple
from scipy import signal
from typing import Dict, List, Sequence, Tuple

from mmeval.core import BaseMetric
from mmeval.utils import try_import
from .utils.image_transforms import reorder_image

if TYPE_CHECKING:
from scipy import signal
else:
signal = try_import('scipy.signal')


class MultiScaleStructureSimilarity(BaseMetric):
"""MS-SSIM (Multi-Scale Structure Similarity) metric.
Expand All @@ -34,13 +29,13 @@ class MultiScaleStructureSimilarity(BaseMetric):
between the maximum the and minimum allowed values).
Defaults to 255.
filter_size (int): Size of blur kernel to use (will be reduced for
small images). Default to 11.
small images). Defaults to 11.
filter_sigma (float): Standard deviation for Gaussian blur kernel (will
be reduced for small images). Default to 1.5.
be reduced for small images). Defaults to 1.5.
k1 (float): Constant used to maintain stability in the SSIM calculation
(0.01 in the original paper). Default to 0.01.
(0.01 in the original paper). Defaults to 0.01.
k2 (float): Constant used to maintain stability in the SSIM calculation
(0.03 in the original paper). Default to 0.03.
(0.03 in the original paper). Defaults to 0.03.
weights (List[float]): List of weights for each level. Defaults to
[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Noted that the default
weights don't sum to 1.0 but do match the paper / matlab code.
Expand Down Expand Up @@ -84,13 +79,15 @@ def __init__(self,
self.weights = np.array(weights)

def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add PSNR score of batch to ``self._results``
"""Add a bunch of images to calculate metric result.
Args:
predictions (Sequence[np.ndarray]): Predictions of the model. The
number of elements in the Sequence must be divisible by 2.
The channel order of each element should align with
`self.input_order` and the range should be [0, 255].
number of elements in the Sequence must be divisible by 2, and
the width and height of each element must be divisible by 2 **
num_scale (`self.weights.size`). The channel order of each
element should align with `self.input_order` and the range
should be [0, 255].
"""

num_samples = len(predictions)
Expand All @@ -103,6 +100,16 @@ def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf
reorder_image(pred, self.input_order) for pred in predictions[1::2]
]

least_size = 2**self.weights.size
assert all([
sample.shape[0] % least_size == 0 for sample in half1
]), ('The height and width of each sample must be divisible by '
f'{least_size} (2 ** len(self.weights.size)).')
assert all([
sample.shape[0] % least_size == 0 for sample in half2
]), ('The height and width of each sample must be divisible by '
f'{least_size} (2 ** self.weights.size).')

half1 = np.stack(half1, axis=0).astype(np.uint8)
half2 = np.stack(half2, axis=0).astype(np.uint8)

Expand Down Expand Up @@ -131,7 +138,7 @@ def compute_ms_ssim(self, img1: np.array, img2: np.array) -> List[float]:
img2 (ndarray): Images with range [0, 255] and order "NHWC".
Returns:
np.ndarray: MS-SSIM score between `img1` and `img2`.
np.ndarray: MS-SSIM score between `img1` and `img2` of shape (N, ).
"""
if img1.shape != img2.shape:
raise RuntimeError(
Expand Down Expand Up @@ -227,15 +234,15 @@ def _ssim_for_multi_scale(
img2 (np.ndarray): Images with range [0, 255] and order "NHWC".
max_val (int): the dynamic range of the images (i.e., the
difference between the maximum the and minimum allowed
values). Default to 255.
values). Defaults to 255.
filter_size (int): Size of blur kernel to use (will be reduced for
small images). Default to 11.
small images). Defaults to 11.
filter_sigma (float): Standard deviation for Gaussian blur kernel (
will be reduced for small images). Default to 1.5.
will be reduced for small images). Defaults to 1.5.
k1 (float): Constant used to maintain stability in the SSIM
calculation (0.01 in the original paper). Default to 0.01.
calculation (0.01 in the original paper). Defaults to 0.01.
k2 (float): Constant used to maintain stability in the SSIM
calculation (0.03 in the original paper). Default to 0.03.
calculation (0.03 in the original paper). Defaults to 0.03.
Returns:
tuple: Pair containing the mean SSIM and contrast sensitivity
Expand Down
9 changes: 2 additions & 7 deletions mmeval/metrics/swd.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
from scipy import ndimage
from typing import Any, Dict, List, Sequence

from mmeval.core import BaseMetric
from mmeval.utils import try_import

if TYPE_CHECKING:
from scipy import ndimage
else:
ndimage = try_import('scipy.ndimage')


class SlicedWassersteinDistance(BaseMetric):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_metrics/test_ms_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ def test_raise_error():
np.random.randint(0, 255, (64, 64, 3)),
np.random.randint(0, 255, (64, 64, 3))
)

with pytest.raises(AssertionError):
inputs = [np.random.randint(0, 255, (3, 32, 32))] * 3
ms_ssim(inputs)

0 comments on commit fb3387e

Please sign in to comment.