diff --git a/mmeval/metrics/ms_ssim.py b/mmeval/metrics/ms_ssim.py index 368458ba..4e65f8b7 100644 --- a/mmeval/metrics/ms_ssim.py +++ b/mmeval/metrics/ms_ssim.py @@ -1,16 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple +from scipy import signal +from typing import Dict, List, Sequence, Tuple from mmeval.core import BaseMetric -from mmeval.utils import try_import from .utils.image_transforms import reorder_image -if TYPE_CHECKING: - from scipy import signal -else: - signal = try_import('scipy.signal') - class MultiScaleStructureSimilarity(BaseMetric): """MS-SSIM (Multi-Scale Structure Similarity) metric. @@ -34,13 +29,13 @@ class MultiScaleStructureSimilarity(BaseMetric): between the maximum the and minimum allowed values). Defaults to 255. filter_size (int): Size of blur kernel to use (will be reduced for - small images). Default to 11. + small images). Defaults to 11. filter_sigma (float): Standard deviation for Gaussian blur kernel (will - be reduced for small images). Default to 1.5. + be reduced for small images). Defaults to 1.5. k1 (float): Constant used to maintain stability in the SSIM calculation - (0.01 in the original paper). Default to 0.01. + (0.01 in the original paper). Defaults to 0.01. k2 (float): Constant used to maintain stability in the SSIM calculation - (0.03 in the original paper). Default to 0.03. + (0.03 in the original paper). Defaults to 0.03. weights (List[float]): List of weights for each level. Defaults to [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Noted that the default weights don't sum to 1.0 but do match the paper / matlab code. @@ -84,13 +79,15 @@ def __init__(self, self.weights = np.array(weights) def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 - """Add PSNR score of batch to ``self._results`` + """Add a batch of images to calculate the metric result.
Args: predictions (Sequence[np.ndarray]): Predictions of the model. The - number of elements in the Sequence must be divisible by 2. - The channel order of each element should align with - `self.input_order` and the range should be [0, 255]. + number of elements in the Sequence must be divisible by 2, and + the width and height of each element must be divisible by 2 ** + num_scale (`self.weights.size`). The channel order of each + element should align with `self.input_order` and the range + should be [0, 255]. """ num_samples = len(predictions) @@ -103,6 +100,16 @@ def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf reorder_image(pred, self.input_order) for pred in predictions[1::2] ] + least_size = 2**self.weights.size + assert all([ + not np.mod(sample.shape[:2], least_size).any() for sample in half1 + ]), ('The height and width of each sample must be divisible by ' + f'{least_size} (2 ** self.weights.size).') + assert all([ + not np.mod(sample.shape[:2], least_size).any() for sample in half2 + ]), ('The height and width of each sample must be divisible by ' + f'{least_size} (2 ** self.weights.size).') + half1 = np.stack(half1, axis=0).astype(np.uint8) half2 = np.stack(half2, axis=0).astype(np.uint8) @@ -131,7 +138,7 @@ def compute_ms_ssim(self, img1: np.array, img2: np.array) -> List[float]: img2 (ndarray): Images with range [0, 255] and order "NHWC". Returns: - np.ndarray: MS-SSIM score between `img1` and `img2`. + np.ndarray: MS-SSIM score between `img1` and `img2` of shape (N, ). """ if img1.shape != img2.shape: raise RuntimeError( @@ -227,15 +234,15 @@ def _ssim_for_multi_scale( img2 (np.ndarray): Images with range [0, 255] and order "NHWC". max_val (int): the dynamic range of the images (i.e., the difference between the maximum the and minimum allowed - values). Default to 255. + values). Defaults to 255. filter_size (int): Size of blur kernel to use (will be reduced for - small images). Default to 11. + small images). Defaults to 11. 
filter_sigma (float): Standard deviation for Gaussian blur kernel ( - will be reduced for small images). Default to 1.5. + will be reduced for small images). Defaults to 1.5. k1 (float): Constant used to maintain stability in the SSIM - calculation (0.01 in the original paper). Default to 0.01. + calculation (0.01 in the original paper). Defaults to 0.01. k2 (float): Constant used to maintain stability in the SSIM - calculation (0.03 in the original paper). Default to 0.03. + calculation (0.03 in the original paper). Defaults to 0.03. Returns: tuple: Pair containing the mean SSIM and contrast sensitivity diff --git a/mmeval/metrics/swd.py b/mmeval/metrics/swd.py index b4d89e11..7cbc89fb 100644 --- a/mmeval/metrics/swd.py +++ b/mmeval/metrics/swd.py @@ -1,14 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from typing import TYPE_CHECKING, Any, Dict, List, Sequence +from scipy import ndimage +from typing import Any, Dict, List, Sequence from mmeval.core import BaseMetric -from mmeval.utils import try_import - -if TYPE_CHECKING: - from scipy import ndimage -else: - ndimage = try_import('scipy.ndimage') class SlicedWassersteinDistance(BaseMetric): diff --git a/tests/test_metrics/test_ms_ssim.py b/tests/test_metrics/test_ms_ssim.py index 8fcbad25..15cc63b8 100644 --- a/tests/test_metrics/test_ms_ssim.py +++ b/tests/test_metrics/test_ms_ssim.py @@ -58,3 +58,7 @@ def test_raise_error(): np.random.randint(0, 255, (64, 64, 3)), np.random.randint(0, 255, (64, 64, 3)) ) + + with pytest.raises(AssertionError): + inputs = [np.random.randint(0, 255, (3, 33, 33))] * 2 + ms_ssim(inputs)