From 71c275614f43605b3f42714cbe25549224d26b39 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Sat, 26 Dec 2020 02:51:18 +0100 Subject: [PATCH] feat: Added automatic layer name resolution (#32) * test: Renamed testers * feat: Added layer resolution utils * feat: Added layer resolution to CAM * test: Added unittests for utils and simplified existing ones * docs: Updated README * refactor: Reflected changes on CAM interface * feat: Updated visualization example script * style: Fixed lint * test: Fixed unittests * feat: Forced eval switch when possible --- README.md | 15 ----- scripts/cam_example.py | 32 ++++++----- test/test_cams.py | 84 ++++++++++++++++------------ test/test_utils.py | 2 +- torchcam/cams/__init__.py | 1 + torchcam/cams/cam.py | 113 ++++++++++++++++++++++++++------------ torchcam/cams/gradcam.py | 39 +++++++------ torchcam/cams/utils.py | 75 +++++++++++++++++++++++++ 8 files changed, 246 insertions(+), 115 deletions(-) create mode 100644 torchcam/cams/utils.py diff --git a/README.md b/README.md index 6c2a7c61..ca659150 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,6 @@ Simple way to leverage the class-specific activation of convolutional layers in * [Prerequisites](#prerequisites) * [Installation](#installation) * [Usage](#usage) -* [Technical Roadmap](#technical-roadmap) * [Documentation](#documentation) * [Contributing](#contributing) * [Credits](#credits) @@ -58,20 +57,6 @@ python scripts/cam_example.py --model resnet50 --class-idx 232 - - -## Technical roadmap - -The project is currently under development, here are the objectives for the next releases: - -- [x] Parallel CAMs: enable batch processing. -- [x] Benchmark: compare class activation map computations for different architectures. -- [ ] Signature improvement: retrieve automatically the specific required layer names. -- [ ] Refined RPN: create a region proposal network using CAM. -- [ ] Task transfer: turn a well-trained classifier into an object detector. - - - ## Documentation The full package documentation is available [here](https://frgfm.github.io/torch-cam/) for detailed specifications. The documentation was built with [Sphinx](sphinx-doc.org) using a [theme](github.com/readthedocs/sphinx_rtd_theme) provided by [Read the Docs](readthedocs.org). diff --git a/scripts/cam_example.py b/scripts/cam_example.py index 4a28e51e..a1ce5a1e 100644 --- a/scripts/cam_example.py +++ b/scripts/cam_example.py @@ -1,10 +1,10 @@ #!usr/bin/python -# -*- coding: utf-8 -*- """ CAM visualization """ +import math import argparse from io import BytesIO @@ -18,18 +18,18 @@ from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM, ISCAM from torchcam.utils import overlay_mask -VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features') +VGG_CONFIG = {_vgg: dict(conv_layer='features') for _vgg in models.vgg.__dict__.keys()} -RESNET_CONFIG = {_resnet: dict(input_layer='conv1', conv_layer='layer4', fc_layer='fc') +RESNET_CONFIG = {_resnet: dict(conv_layer='layer4', fc_layer='fc') for _resnet in models.resnet.__dict__.keys()} -DENSENET_CONFIG = {_densenet: dict(input_layer='features', conv_layer='features', fc_layer='classifier') +DENSENET_CONFIG = {_densenet: dict(conv_layer='features', fc_layer='classifier') for _densenet in models.densenet.__dict__.keys()} MODEL_CONFIG = { **VGG_CONFIG, **RESNET_CONFIG, **DENSENET_CONFIG, - 'mobilenet_v2': dict(input_layer='features', conv_layer='features') + 'mobilenet_v2': dict(conv_layer='features') } @@ -43,7 +43,6 @@ def main(args): # Pretrained imagenet model model = models.__dict__[args.model](pretrained=True).eval().to(device=device) conv_layer = MODEL_CONFIG[args.model]['conv_layer'] - input_layer = MODEL_CONFIG[args.model]['input_layer'] fc_layer = MODEL_CONFIG[args.model]['fc_layer'] # Image @@ -57,15 +56,17 @@ def main(args): # Hook the corresponding layer in the model cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer), - GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer), - ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer), - ISCAM(model, conv_layer, input_layer)] + GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer), + ScoreCAM(model, conv_layer), SSCAM(model, conv_layer), + ISCAM(model, conv_layer)] # Don't trigger all hooks for extractor in cam_extractors: extractor._hooks_enabled = False - fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2)) + num_rows = 2 + num_cols = math.ceil(len(cam_extractors) / num_rows) + _, axes = plt.subplots(num_rows, num_cols, figsize=(6, 4)) for idx, extractor in enumerate(cam_extractors): extractor._hooks_enabled = True model.zero_grad() @@ -76,6 +77,7 @@ def main(args): # Use the hooked data to compute activation map activation_map = extractor(class_idx, scores).cpu() + # Clean data extractor.clear_hooks() extractor._hooks_enabled = False @@ -85,9 +87,13 @@ def main(args): # Plot the result result = overlay_mask(pil_img, heatmap) - axes[idx].imshow(result) - axes[idx].axis('off') - axes[idx].set_title(extractor.__class__.__name__, size=8) + axes[idx // num_cols][idx % num_cols].imshow(result) + axes[idx // num_cols][idx % num_cols].set_title(extractor.__class__.__name__, size=8) + + # Clear axes + for row in axes: + for ax in row: + ax.axis('off') plt.tight_layout() if args.savefig: diff --git a/test/test_cams.py b/test/test_cams.py index 7f7359b9..6655ade7 100644 --- a/test/test_cams.py +++ b/test/test_cams.py @@ -4,6 +4,7 @@ import requests import torch from PIL import Image +from torch import nn from torchvision.models import mobilenet_v2, resnet18 from torchvision.transforms.functional import normalize, resize, to_tensor @@ -20,7 +21,7 @@ def _forward(model, input_tensor): return scores -class Tester(unittest.TestCase): +class CAMCoreTester(unittest.TestCase): def _verify_cam(self, cam): # Simple verifications self.assertIsInstance(cam, torch.Tensor) @@ -67,76 +68,91 @@ def _test_extractor(self, extractor, model): def _test_cam(self, name): # Get a pretrained model - model = resnet18(pretrained=False).eval() - conv_layer = 'layer4' - input_layer = 'conv1' - fc_layer = 'fc' - - # Hook the corresponding layer in the model - extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer) - - self._test_extractor(extractor, model) - - def _test_cam_arbitrary_layer(self, name): - model = resnet18(pretrained=False).eval() conv_layer = 'layer4.1.relu' - input_layer = 'conv1' - fc_layer = 'fc' # Hook the corresponding layer in the model - extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer) + extractor = cams.__dict__[name](model, conv_layer) - self._test_extractor(extractor, model) + with torch.no_grad(): + self._test_extractor(extractor, model) def _test_gradcam(self, name): # Get a pretrained model model = mobilenet_v2(pretrained=False) - conv_layer = 'features' + conv_layer = 'features.17.conv.3' # Hook the corresponding layer in the model extractor = cams.__dict__[name](model, conv_layer) self._test_extractor(extractor, model) - def _test_gradcam_arbitrary_layer(self, name): + def test_smooth_gradcampp(self): - model = mobilenet_v2(pretrained=False) - conv_layer = 'features.17.conv.3' + # Get a pretrained model + model = mobilenet_v2(pretrained=False).eval() # Hook the corresponding layer in the model - extractor = cams.__dict__[name](model, conv_layer) + extractor = cams.SmoothGradCAMpp(model) self._test_extractor(extractor, model) - def test_smooth_gradcampp(self): - # Get a pretrained model - model = mobilenet_v2(pretrained=False) - conv_layer = 'features' - input_layer = 'features' +class CAMUtilsTester(unittest.TestCase): - # Hook the corresponding layer in the model - extractor = cams.SmoothGradCAMpp(model, conv_layer, input_layer) + @staticmethod + def _get_custom_module(): - self._test_extractor(extractor, model) + mod = nn.Sequential( + nn.Sequential( + nn.Conv2d(3, 8, 3, 1), + nn.ReLU(), + nn.Conv2d(8, 16, 3, 1), + nn.ReLU(), + nn.AdaptiveAvgPool2d((1, 1)) + ), + nn.Flatten(1), + nn.Linear(16, 1) + ) + return mod + + def test_locate_candidate_layer(self): + + # ResNet-18 + mod = resnet18().eval() + self.assertEqual(cams.utils.locate_candidate_layer(mod), 'layer4') + + # Custom model + mod = self._get_custom_module() + + self.assertEqual(cams.utils.locate_candidate_layer(mod), '0.3') + # Check that the model is switched back to its origin mode afterwards + self.assertTrue(mod.training) + + def test_locate_linear_layer(self): + + # ResNet-18 + mod = resnet18().eval() + self.assertEqual(cams.utils.locate_linear_layer(mod), 'fc') + + # Custom model + mod = self._get_custom_module() + self.assertEqual(cams.utils.locate_linear_layer(mod), '2') for cam_extractor in ['CAM', 'ScoreCAM', 'SSCAM', 'ISCAM']: def do_test(self, cam_extractor=cam_extractor): self._test_cam(cam_extractor) - self._test_cam_arbitrary_layer(cam_extractor) - setattr(Tester, "test_" + cam_extractor.lower(), do_test) + setattr(CAMCoreTester, "test_" + cam_extractor.lower(), do_test) for cam_extractor in ['GradCAM', 'GradCAMpp']: def do_test(self, cam_extractor=cam_extractor): self._test_gradcam(cam_extractor) - self._test_gradcam_arbitrary_layer(cam_extractor) - setattr(Tester, "test_" + cam_extractor.lower(), do_test) + setattr(CAMCoreTester, "test_" + cam_extractor.lower(), do_test) if __name__ == '__main__': diff --git a/test/test_utils.py b/test/test_utils.py index c7be1633..1d715c80 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -6,7 +6,7 @@ from torchcam import utils -class Tester(unittest.TestCase): +class UtilsTester(unittest.TestCase): def test_overlay_mask(self): img = Image.fromarray(np.zeros((4, 4, 3)).astype(np.uint8)) diff --git a/torchcam/cams/__init__.py b/torchcam/cams/__init__.py index 41225277..ad26e96e 100644 --- a/torchcam/cams/__init__.py +++ b/torchcam/cams/__init__.py @@ -1,2 +1,3 @@ from .cam import * from .gradcam import * +from .utils import * diff --git a/torchcam/cams/cam.py b/torchcam/cams/cam.py index fe5ae0a8..cffc51aa 100644 --- a/torchcam/cams/cam.py +++ b/torchcam/cams/cam.py @@ -1,9 +1,12 @@ import math +import logging import torch from torch import Tensor from torch import nn import torch.nn.functional as F -from typing import Optional, List +from typing import Optional, List, Tuple + +from .utils import locate_candidate_layer, locate_linear_layer __all__ = ['CAM', 'ScoreCAM', 'SSCAM', 'ISCAM'] @@ -13,26 +16,38 @@ class _CAM: Args: model: input model - conv_layer: name of the last convolutional layer + target_layer: name of the target layer + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: nn.Module, - conv_layer: str + target_layer: Optional[str] = None, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: # Obtain a mapping from module name to module instance for each layer in the model self.submodule_dict = dict(model.named_modules()) - if conv_layer not in self.submodule_dict.keys(): - raise ValueError(f"Unable to find submodule {conv_layer} in the model") + # If the layer is not specified, try automatic resolution + if target_layer is None: + target_layer = locate_candidate_layer(model, input_shape) + # Warn the user of the choice + if isinstance(target_layer, str): + logging.warning(f"no value was provided for `target_layer`, thus set to '{target_layer}'.") + else: + raise ValueError("unable to resolve `target_layer` automatically, please specify its value.") + + if target_layer not in self.submodule_dict.keys(): + raise ValueError(f"Unable to find submodule {target_layer} in the model") + self.target_layer = target_layer self.model = model # Init hooks self.hook_a: Optional[Tensor] = None self.hook_handles: List[torch.utils.hooks.RemovableHandle] = [] # Forward hook - self.hook_handles.append(self.submodule_dict[conv_layer].register_forward_hook(self._hook_a)) + self.hook_handles.append(self.submodule_dict[target_layer].register_forward_hook(self._hook_a)) # Enable hooks self._hooks_enabled = True # Should ReLU be used before normalization @@ -129,7 +144,7 @@ class CAM(_CAM): .. math:: L^{(c)}_{CAM}(x, y) = ReLU\\Big(\\sum\\limits_k w_k^{(c)} A_k(x, y)\\Big) - where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k` in the fully connected layer.. @@ -144,18 +159,29 @@ class CAM(_CAM): Args: model: input model - conv_layer: name of the last convolutional layer + target_layer: name of the target layer fc_layer: name of the fully convolutional layer + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: nn.Module, - conv_layer: str, - fc_layer: str + target_layer: Optional[str] = None, + fc_layer: Optional[str] = None, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: - super().__init__(model, conv_layer) + super().__init__(model, target_layer, input_shape) + + # If the layer is not specified, try automatic resolution + if fc_layer is None: + fc_layer = locate_linear_layer(model) + # Warn the user of the choice + if isinstance(fc_layer, str): + logging.warning(f"no value was provided for `fc_layer`, thus set to '{fc_layer}'.") + else: + raise ValueError("unable to resolve `fc_layer` automatically, please specify its value.") # Softmax weight self._fc_weights = self.submodule_dict[fc_layer].weight.data @@ -180,7 +206,7 @@ class ScoreCAM(_CAM): .. math:: w_k^{(c)} = softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b)) - where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax for input :math:`X`, :math:`X_b` is a baseline image, and :math:`M_k` is defined as follows: @@ -195,29 +221,29 @@ class ScoreCAM(_CAM): >>> from torchvision.models import resnet18 >>> from torchcam.cams import ScoreCAM >>> model = resnet18(pretrained=True).eval() - >>> cam = ScoreCAM(model, 'layer4', 'conv1') + >>> cam = ScoreCAM(model, 'layer4') >>> with torch.no_grad(): out = model(input_tensor) >>> cam(class_idx=100) Args: model: input model - conv_layer: name of the last convolutional layer - input_layer: name of the first layer + target_layer: name of the target layer batch_size: batch size used to forward masked inputs + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: nn.Module, - conv_layer: str, - input_layer: str, - batch_size: int = 32 + target_layer: Optional[str] = None, + batch_size: int = 32, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: - super().__init__(model, conv_layer) + super().__init__(model, target_layer, input_shape) # Input hook - self.hook_handles.append(self.submodule_dict[input_layer].register_forward_pre_hook(self._store_input)) + self.hook_handles.append(model.register_forward_pre_hook(self._store_input)) self.bs = batch_size # Ensure ReLU is applied to CAM before normalization self._relu = True @@ -248,6 +274,9 @@ def _get_weights(self, class_idx: int, scores: Optional[Tensor] = None) -> Tenso # Disable hook updates self._hooks_enabled = False + # Switch to eval + origin_mode = self.model.training + self.model.eval() # Process by chunk (GPU RAM limitation) for idx in range(math.ceil(weights.shape[0] / self.bs)): @@ -258,6 +287,8 @@ def _get_weights(self, class_idx: int, scores: Optional[Tensor] = None) -> Tenso # Reenable hook updates self._hooks_enabled = True + # Put back the model in the correct mode + self.model.training = origin_mode return weights @@ -280,7 +311,7 @@ class SSCAM(ScoreCAM): w_k^{(c)} = \\frac{1}{N} \\sum\\limits_1^N softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b)) where :math:`N` is the number of samples used to smooth the weights, - :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax for input :math:`X`, :math:`X_b` is a baseline image, and :math:`M_k` is defined as follows: @@ -297,30 +328,30 @@ class SSCAM(ScoreCAM): >>> from torchvision.models import resnet18 >>> from torchcam.cams import SSCAM >>> model = resnet18(pretrained=True).eval() - >>> cam = SSCAM(model, 'layer4', 'conv1') + >>> cam = SSCAM(model, 'layer4') >>> with torch.no_grad(): out = model(input_tensor) >>> cam(class_idx=100) Args: model: input model - conv_layer: name of the last convolutional layer - input_layer: name of the first layer + target_layer: name of the target layer batch_size: batch size used to forward masked inputs num_samples: number of noisy samples used for weight computation std: standard deviation of the noise added to the normalized activation + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: nn.Module, - conv_layer: str, - input_layer: str, + target_layer: Optional[str] = None, batch_size: int = 32, num_samples: int = 35, - std: float = 2.0 + std: float = 2.0, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: - super().__init__(model, conv_layer, input_layer, batch_size) + super().__init__(model, target_layer, batch_size, input_shape) self.num_samples = num_samples self.std = std @@ -346,6 +377,9 @@ def _get_weights(self, class_idx: int, scores: Optional[Tensor] = None) -> Tenso # Disable hook updates self._hooks_enabled = False + # Switch to eval + origin_mode = self.model.training + self.model.eval() for _idx in range(self.num_samples): noisy_m = self._input * (upsampled_a + @@ -363,6 +397,8 @@ def _get_weights(self, class_idx: int, scores: Optional[Tensor] = None) -> Tenso # Reenable hook updates self._hooks_enabled = True + # Put back the model in the correct mode + self.model.training = origin_mode return weights @@ -385,7 +421,7 @@ class ISCAM(ScoreCAM): w_k^{(c)} = \\sum\\limits_{i=1}^N \\frac{i}{N} softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b)) where :math:`N` is the number of samples used to smooth the weights, - :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax for input :math:`X`, :math:`X_b` is a baseline image, and :math:`M_k` is defined as follows: @@ -402,28 +438,28 @@ class ISCAM(ScoreCAM): >>> from torchvision.models import resnet18 >>> from torchcam.cams import ISSCAM >>> model = resnet18(pretrained=True).eval() - >>> cam = ISCAM(model, 'layer4', 'conv1') + >>> cam = ISCAM(model, 'layer4') >>> with torch.no_grad(): out = model(input_tensor) >>> cam(class_idx=100) Args: model: input model - conv_layer: name of the last convolutional layer - input_layer: name of the first layer + target_layer: name of the target layer batch_size: batch size used to forward masked inputs num_samples: number of noisy samples used for weight computation + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: nn.Module, - conv_layer: str, - input_layer: str, + target_layer: Optional[str] = None, batch_size: int = 32, - num_samples: int = 10 + num_samples: int = 10, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: - super().__init__(model, conv_layer, input_layer, batch_size) + super().__init__(model, target_layer, batch_size, input_shape) self.num_samples = num_samples @@ -449,6 +485,9 @@ def _get_weights(self, class_idx: int, scores: Optional[Tensor] = None) -> Tenso self._hooks_enabled = False fmap = torch.zeros((upsampled_a.shape[0], *self._input.shape[1:]), dtype=upsampled_a.dtype, device=upsampled_a.device) + # Switch to eval + origin_mode = self.model.training + self.model.eval() for _idx in range(self.num_samples): fmap += (_idx + 1) / self.num_samples * self._input * upsampled_a @@ -463,5 +502,7 @@ def _get_weights(self, class_idx: int, scores: Optional[Tensor] = None) -> Tenso # Reenable hook updates self._hooks_enabled = True + # Put back the model in the correct mode + self.model.training = origin_mode return weights diff --git a/torchcam/cams/gradcam.py b/torchcam/cams/gradcam.py index f43c06f6..7bf44178 100644 --- a/torchcam/cams/gradcam.py +++ b/torchcam/cams/gradcam.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from typing import Optional +from typing import Optional, Tuple from .cam import _CAM @@ -12,16 +12,18 @@ class _GradCAM(_CAM): Args: model: input model - conv_layer: name of the last convolutional layer + target_layer: name of the target layer + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: torch.nn.Module, - conv_layer: str + target_layer: Optional[str] = None, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: - super().__init__(model, conv_layer) + super().__init__(model, target_layer, input_shape) # Init hook self.hook_g: Optional[Tensor] = None # Ensure ReLU is applied before normalization @@ -29,7 +31,7 @@ def __init__( # Model output is used by the extractor self._score_used = True # Backward hook - self.hook_handles.append(self.submodule_dict[conv_layer].register_backward_hook(self._hook_g)) + self.hook_handles.append(self.submodule_dict[self.target_layer].register_backward_hook(self._hook_g)) def _hook_g(self, module: torch.nn.Module, input: Tensor, output: Tensor) -> None: """Gradient hook""" @@ -67,7 +69,7 @@ class GradCAM(_GradCAM): w_k^{(c)} = \\frac{1}{H \\cdot W} \\sum\\limits_{i=1}^H \\sum\\limits_{j=1}^W \\frac{\\partial Y^{(c)}}{\\partial A_k(i, j)} - where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, and :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax. @@ -81,7 +83,8 @@ class GradCAM(_GradCAM): Args: model: input model - conv_layer: name of the last convolutional layer + target_layer: name of the target layer + input_shape: shape of the expected input tensor excluding the batch dimension """ def _get_weights(self, class_idx: int, scores: Tensor) -> Tensor: # type: ignore[override] @@ -109,7 +112,7 @@ class GradCAMpp(_GradCAM): w_k^{(c)} = \\sum\\limits_{i=1}^H \\sum\\limits_{j=1}^W \\alpha_k^{(c)}(i, j) \\cdot ReLU\\Big(\\frac{\\partial Y^{(c)}}{\\partial A_k(i, j)}\\Big) - where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax, and :math:`\\alpha_k^{(c)}(i, j)` being defined as: @@ -132,7 +135,8 @@ class GradCAMpp(_GradCAM): Args: model: input model - conv_layer: name of the last convolutional layer + target_layer: name of the target layer + input_shape: shape of the expected input tensor excluding the batch dimension """ def _get_weights(self, class_idx: int, scores: Tensor) -> Tensor: # type: ignore[override] @@ -166,7 +170,7 @@ class SmoothGradCAMpp(_GradCAM): w_k^{(c)} = \\sum\\limits_{i=1}^H \\sum\\limits_{j=1}^W \\alpha_k^{(c)}(i, j) \\cdot ReLU\\Big(\\frac{\\partial Y^{(c)}}{\\partial A_k(i, j)}\\Big) - where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at position :math:`(x, y)`, :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax, and :math:`\\alpha_k^{(c)}(i, j)` being defined as: @@ -197,24 +201,27 @@ class SmoothGradCAMpp(_GradCAM): Args: model: input model - conv_layer: name of the last convolutional layer + target_layer: name of the target layer + num_samples: number of samples to use for smoothing + std: standard deviation of the noise + input_shape: shape of the expected input tensor excluding the batch dimension """ def __init__( self, model: torch.nn.Module, - conv_layer: str, - first_layer: str, + target_layer: Optional[str] = None, num_samples: int = 4, - std: float = 0.3 + std: float = 0.3, + input_shape: Tuple[int, ...] = (3, 224, 224), ) -> None: - super().__init__(model, conv_layer) + super().__init__(model, target_layer, input_shape) # Model scores is not used by the extractor self._score_used = False # Input hook - self.hook_handles.append(self.submodule_dict[first_layer].register_forward_pre_hook(self._store_input)) + self.hook_handles.append(model.register_forward_pre_hook(self._store_input)) # Noise distribution self.num_samples = num_samples self.std = std diff --git a/torchcam/cams/utils.py b/torchcam/cams/utils.py new file mode 100644 index 00000000..83b9ceb6 --- /dev/null +++ b/torchcam/cams/utils.py @@ -0,0 +1,75 @@ +import torch +from torch import Tensor +from torch import nn +from typing import List, Optional, Tuple +from functools import partial + +__all__ = ['locate_candidate_layer', 'locate_linear_layer'] + + +def locate_candidate_layer(mod: nn.Module, input_shape: Tuple[int, ...] = (3, 224, 224)) -> Optional[str]: + """Attempts to find a candidate layer to use for CAM extraction + + Args: + mod: the module to inspect + input_shape: the expected shape of input tensor excluding the batch dimension + + Returns: + str: the candidate layer for CAM + """ + + # Set module in eval mode + module_mode = mod.training + mod.eval() + + output_shapes: List[Tuple[Optional[str], Tuple[int, ...]]] = [] + + def _record_output_shape(module: nn.Module, input: Tensor, output: Tensor, name: Optional[str] = None) -> None: + """Activation hook""" + output_shapes.append((name, output.shape)) + + hook_handles: List[torch.utils.hooks.RemovableHandle] = [] + # forward hook on all layers + for n, m in mod.named_modules(): + hook_handles.append(m.register_forward_hook(partial(_record_output_shape, name=n))) + + # forward empty + with torch.no_grad(): + _ = mod(torch.rand(1, *input_shape)) + + # Remove all temporary hooks + for handle in hook_handles: + handle.remove() + + # Put back the model in the corresponding mode + mod.training = module_mode + + # Check output shapes + candidate_layer = None + for layer_name, output_shape in output_shapes: + # Stop before flattening or global pooling + if len(output_shape) != (len(input_shape) + 1) or all(v == 1 for v in output_shape[2:]): + break + else: + candidate_layer = layer_name + + return candidate_layer + + +def locate_linear_layer(mod: nn.Module) -> Optional[str]: + """Attempts to find a fully connecter layer to use for CAM extraction + + Args: + mod: the module to inspect + + Returns: + str: the candidate layer + """ + + candidate_layer = None + for layer_name, m in mod.named_modules(): + if isinstance(m, nn.Linear): + candidate_layer = layer_name + break + + return candidate_layer