From 7be0b4ff4dd938073c6d90724b30889f23c875b3 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Tue, 24 Mar 2020 02:25:15 +0100 Subject: [PATCH] feat: Added original CAM implementation (#2) * docs: Fixed usage instruction * feat: Added original CAM implementation * chore: Reorganized package * test: Added CAM unittest * refactor: Refactored CAM * refactor: Refactored GradCAMs * refactor: Refactored CAMs * test: Updated unittests accordingly * docs: Updated readme * style: Removed extra blank line * docs: Updated documentation --- README.md | 7 ++- docs/source/cams.rst | 24 ++++++++ docs/source/index.rst | 2 +- docs/source/torchcam.rst | 16 ----- test/test_cams.py | 82 +++++++++++++++++++++++++ test/test_gradcam.py | 55 ----------------- torchcam/__init__.py | 2 +- torchcam/cams/__init__.py | 5 ++ torchcam/cams/cam.py | 80 +++++++++++++++++++++++++ torchcam/cams/gradcam.py | 108 +++++++++++++++++++++++++++++++++ torchcam/gradcam.py | 123 -------------------------------------- 11 files changed, 305 insertions(+), 199 deletions(-) create mode 100644 docs/source/cams.rst delete mode 100644 docs/source/torchcam.rst create mode 100644 test/test_cams.py delete mode 100644 test/test_gradcam.py create mode 100644 torchcam/cams/__init__.py create mode 100644 torchcam/cams/cam.py create mode 100644 torchcam/cams/gradcam.py delete mode 100644 torchcam/gradcam.py diff --git a/README.md b/README.md index a2ce6ca4..7823ea26 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,8 @@ import matplotlib.pyplot as plt from torchvision.models import resnet50 from torchvision.transforms import transforms from torchvision.transforms.functional import to_pil_image -from gradcam import GradCAM, GradCAMpp, overlay_mask +from torchcam.cams import CAM, GradCAM, GradCAMpp +from torchcam.utils import overlay_mask # Pretrained imagenet model @@ -81,7 +82,7 @@ classes = {int(key):value for (key, value) class_idx = 232 # Use the hooked data to compute activation map -activation_maps = gradcam.get_activation_maps(out, class_idx) +activation_maps = gradcam(out, class_idx) # Convert it to PIL image # The indexing below means first image in batch heatmap = to_pil_image(activation_maps[0].cpu().numpy(), mode='F') @@ -101,7 +102,7 @@ plt.imshow(result); plt.axis('off'); plt.title(classes.get(class_idx)); plt.tigh The project is currently under development, here are the objectives for the next releases: -- [ ] Parallel CAMs: enable batch processing. +- [x] Parallel CAMs: enable batch processing. - [ ] Benchmark: compare class activation map computations for different architectures. - [ ] Signature improvement: retrieve automatically the last convolutional layer. - [ ] Refine RPN: create a region proposal network using CAM. diff --git a/docs/source/cams.rst b/docs/source/cams.rst new file mode 100644 index 00000000..595c6459 --- /dev/null +++ b/docs/source/cams.rst @@ -0,0 +1,24 @@ +torchcam.cams +============= + + +.. currentmodule:: torchcam.cams + + +CAM +-------- +Related to activation-based class activation maps. + + +.. autoclass:: CAM + + +Grad-CAM +-------- +Related to gradient-based class activation maps. + + +.. autoclass:: GradCAM + + +.. autoclass:: GradCAMpp diff --git a/docs/source/index.rst b/docs/source/index.rst index 62c203a7..ac35da8b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -7,7 +7,7 @@ The :mod:`torchcam` package gives PyTorch users the possibility to visualize the :maxdepth: 1 :caption: Package Reference - torchcam + cams utils diff --git a/docs/source/torchcam.rst b/docs/source/torchcam.rst deleted file mode 100644 index 1eafc461..00000000 --- a/docs/source/torchcam.rst +++ /dev/null @@ -1,16 +0,0 @@ -torchcam -========= - - -.. currentmodule:: torchcam - - -Grad-CAM --------- -Related to gradient-based class activation maps. - - -.. autoclass:: GradCAM - - -.. autoclass:: GradCAMpp diff --git a/test/test_cams.py b/test/test_cams.py new file mode 100644 index 00000000..83da7cc4 --- /dev/null +++ b/test/test_cams.py @@ -0,0 +1,82 @@ +import unittest +import requests +from io import BytesIO +from PIL import Image +import torch +from torchvision.models import resnet18, mobilenet_v2 +from torchvision.transforms.functional import resize, to_tensor, normalize + +from torchcam import cams + + +class Tester(unittest.TestCase): + + def _verify_cam(self, cam): + # Simple verifications + self.assertIsInstance(cam, torch.Tensor) + self.assertEqual(cam.shape, (1, 7, 7)) + + @staticmethod + def _get_img_tensor(): + + # Get a dog image + URL = 'https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg' + response = requests.get(URL) + + # Forward an image + pil_img = Image.open(BytesIO(response.content), mode='r').convert('RGB') + img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))), + [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + + return img_tensor + + def test_cam(self): + # Get a pretrained model + model = resnet18(pretrained=True).eval() + conv_layer = 'layer4' + fc_layer = 'fc' + # Border collie index in ImageNet + class_idx = 232 + + # Hook the corresponding layer in the model + extractor = cams.CAM(model, conv_layer, fc_layer) + + # Get a dog image + img_tensor = self._get_img_tensor() + # Forward it + with torch.no_grad(): + _ = model(img_tensor.unsqueeze(0)) + + # Use the hooked data to compute activation map + self._verify_cam(extractor(class_idx)) + + def _test_gradcam(self, name): + + # Get a pretrained model + model = mobilenet_v2(pretrained=True) + conv_layer = 'features' + # Border collie index in ImageNet + class_idx = 232 + + # Hook the corresponding layer in the model + extractor = cams.__dict__[name](model, conv_layer) + + # Get a dog image + img_tensor = self._get_img_tensor() + + # Forward an image + out = model(img_tensor.unsqueeze(0)) + + # Use the hooked data to compute activation map + self._verify_cam(extractor(out, class_idx)) + + +for cam_extractor in ['GradCAM', 'GradCAMpp']: + def do_test(self, cam_extractor=cam_extractor): + self._test_gradcam(cam_extractor) + + setattr(Tester, "test_" + cam_extractor.lower(), do_test) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_gradcam.py b/test/test_gradcam.py deleted file mode 100644 index 93de740a..00000000 --- a/test/test_gradcam.py +++ /dev/null @@ -1,55 +0,0 @@ -import unittest -import requests -from io import BytesIO -from PIL import Image -import torch -from torchvision.models import mobilenet_v2 -from torchvision.transforms import transforms - -from torchcam import gradcam - - -class Tester(unittest.TestCase): - - def _test_gradcam(self, name): - - # Get a pretrained model - model = mobilenet_v2(pretrained=True) - conv_layer = 'features' - - # Hook the corresponding layer in the model - extractor = gradcam.__dict__[name](model, conv_layer) - - # Get a dog image - URL = 'https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg' - response = requests.get(URL) - - # Forward an image - pil_img = Image.open(BytesIO(response.content), mode='r').convert('RGB') - preprocess = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]) - img_tensor = preprocess(pil_img) - out = model(img_tensor.unsqueeze(0)) - - # Border collie index in ImageNet - class_idx = 232 - - # Use the hooked data to compute activation map - activation_map = extractor.get_activation_maps(out, class_idx) - - self.assertIsInstance(activation_map, torch.Tensor) - self.assertEqual(activation_map.shape, (1, 7, 7)) - - -for cam_extractor in ['GradCAM', 'GradCAMpp']: - def do_test(self, cam_extractor=cam_extractor): - self._test_gradcam(cam_extractor) - - setattr(Tester, "test_" + cam_extractor.lower(), do_test) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchcam/__init__.py b/torchcam/__init__.py index 9a3b5db9..23127019 100644 --- a/torchcam/__init__.py +++ b/torchcam/__init__.py @@ -1,4 +1,4 @@ -from torchcam.gradcam import * +from torchcam import cams from torchcam import utils diff --git a/torchcam/cams/__init__.py b/torchcam/cams/__init__.py new file mode 100644 index 00000000..d5ec854b --- /dev/null +++ b/torchcam/cams/__init__.py @@ -0,0 +1,5 @@ +from .cam import * +from .gradcam import * + +del cam +del gradcam diff --git a/torchcam/cams/cam.py b/torchcam/cams/cam.py new file mode 100644 index 00000000..14e24106 --- /dev/null +++ b/torchcam/cams/cam.py @@ -0,0 +1,80 @@ +#!usr/bin/python +# -*- coding: utf-8 -*- + +""" +GradCAM +""" + +import torch + + +__all__ = ['CAM'] + + +class _CAM(object): + """Implements a class activation map extractor + + Args: + model (torch.nn.Module): input model + conv_layer (str): name of the last convolutional layer + """ + + hook_a = None + + def __init__(self, model, conv_layer): + + if not hasattr(model, conv_layer): + raise ValueError(f"Unable to find submodule {conv_layer} in the model") + self.model = model + # Forward hook + self.model._modules.get(conv_layer).register_forward_hook(self._hook_a) + + def _hook_a(self, module, input, output): + self.hook_a = output.data + + @staticmethod + def _normalize(cams): + cams -= cams.flatten(start_dim=1).min().view(-1, 1, 1) + cams /= cams.flatten(start_dim=1).max().view(-1, 1, 1) + + return cams + + def _get_weights(self, class_idx): + + raise NotImplementedError + + def __call__(self, class_idx, normalized=True): + + # Get map weight + weights = self._get_weights(class_idx) + + # Perform the weighted combination to get the CAM + batch_cams = (weights.view(-1, 1, 1) * self.hook_a).sum(dim=1) + + # Normalize the CAM + if normalized: + batch_cams = self._normalize(batch_cams) + + return batch_cams + + +class CAM(_CAM): + """Implements a class activation map extractor as described in https://arxiv.org/abs/1512.04150 + + Args: + model (torch.nn.Module): input model + conv_layer (str): name of the last convolutional layer + """ + + hook_a = None + + def __init__(self, model, conv_layer, fc_layer): + + super().__init__(model, conv_layer) + # Softmax weight + self._fc_weights = self.model._modules.get(fc_layer).weight.data + + def _get_weights(self, class_idx): + + # Take the FC weights of the target class + return self._fc_weights[class_idx, :] diff --git a/torchcam/cams/gradcam.py b/torchcam/cams/gradcam.py new file mode 100644 index 00000000..b9c9b455 --- /dev/null +++ b/torchcam/cams/gradcam.py @@ -0,0 +1,108 @@ +#!usr/bin/python +# -*- coding: utf-8 -*- + +""" +GradCAM +""" + +import torch +from .cam import _CAM + + +__all__ = ['GradCAM', 'GradCAMpp'] + + +class _GradCAM(_CAM): + """Implements a gradient-based class activation map extractor + + Args: + model (torch.nn.Module): input model + conv_layer (str): name of the last convolutional layer + """ + + hook_a, hook_g = None, None + + def __init__(self, model, conv_layer): + + super().__init__(model, conv_layer) + # Backward hook + self.model._modules.get(conv_layer).register_backward_hook(self._hook_g) + + def _hook_g(self, module, input, output): + self.hook_g = output[0].data + + def _backprop(self, output, class_idx): + + if self.hook_a is None: + raise TypeError("Inputs need to be forwarded in the model for the conv features to be hooked") + + # Backpropagate to get the gradients on the hooked layer + loss = output[:, class_idx] + self.model.zero_grad() + loss.backward(retain_graph=True) + + def _get_weights(self, output, class_idx): + + raise NotImplementedError + + def __call__(self, output, class_idx, normalized=True): + + # Backpropagate + self._backprop(output, class_idx) + + # Get map weight + weights = self._get_weights(output, class_idx) + + # Perform the weighted combination to get the CAM + batch_cams = torch.relu((weights.view(*weights.shape, 1, 1) * self.hook_a).sum(dim=1)) + + # Normalize the CAM + if normalized: + batch_cams = self._normalize(batch_cams) + + return batch_cams + + +class GradCAM(_GradCAM): + """Implements a class activation map extractor as described in https://arxiv.org/pdf/1710.11063.pdf + + Args: + model (torch.nn.Module): input model + conv_layer (str): name of the last convolutional layer + """ + + hook_a, hook_g = None, None + + def __init__(self, model, conv_layer): + + super().__init__(model, conv_layer) + + def _get_weights(self, output, class_idx): + + # Global average pool the gradients over spatial dimensions + return self.hook_g.data.mean(axis=(2, 3)) + + +class GradCAMpp(_GradCAM): + """Implements a class activation map extractor as described in https://arxiv.org/pdf/1710.11063.pdf + + Args: + model (torch.nn.Module): input model + conv_layer (str): name of the last convolutional layer + """ + + hook_a, hook_g = None, None + + def __init__(self, model, conv_layer): + + super().__init__(model, conv_layer) + + def _get_weights(self, output, class_idx): + + # Alpha coefficient for each pixel + grad_2 = self.hook_g.data.pow(2) + grad_3 = self.hook_g.data.pow(3) + alpha = grad_2 / (2 * grad_2 + (grad_3 * self.hook_a.data).sum(axis=(2, 3), keepdims=True)) + + # Apply pixel coefficient in each weight + return alpha.mul(torch.relu(self.hook_g.data)).sum(axis=(2, 3)) diff --git a/torchcam/gradcam.py b/torchcam/gradcam.py deleted file mode 100644 index ccf428f4..00000000 --- a/torchcam/gradcam.py +++ /dev/null @@ -1,123 +0,0 @@ -#!usr/bin/python -# -*- coding: utf-8 -*- - -""" -GradCAM -""" - -import torch - - -__all__ = ['GradCAM', 'GradCAMpp'] - - -class GradCAM(object): - """Implements a class activation map extractor as described in https://arxiv.org/pdf/1610.02391.pdf - - Args: - model (torch.nn.Module): input model - conv_layer (str): name of the last convolutional layer - """ - - hook_a, hook_g = None, None - - def __init__(self, model, conv_layer): - - if not hasattr(model, conv_layer): - raise ValueError(f"Unable to find submodule {conv_layer} in the model") - self.model = model - # Forward hook - self.model._modules.get(conv_layer).register_forward_hook(self._hook_a) - # Backward hook - self.model._modules.get(conv_layer).register_backward_hook(self._hook_g) - - def _hook_a(self, module, input, output): - self.hook_a = output.data - - def _hook_g(self, module, input, output): - self.hook_g = output[0].data - - def _compute_gradcams(self, weights, normalized=True): - - # Get the feature activation map - fmap = self.hook_a.data - # Perform the weighted combination to get the CAM - batch_cams = torch.relu((weights.view(*weights.shape, 1, 1) * fmap).sum(dim=1)) - - # Normalize the CAM - if normalized: - batch_cams -= batch_cams.flatten(start_dim=1).min().view(-1, 1, 1) - batch_cams /= batch_cams.flatten(start_dim=1).max().view(-1, 1, 1) - - return batch_cams - - def _backprop(self, output, class_idx): - - if self.hook_a is None: - raise TypeError("Inputs need to be forwarded in the model for the conv features to be hooked") - - # Backpropagate to get the gradients on the hooked layer - loss = output[:, class_idx] - self.model.zero_grad() - loss.backward(retain_graph=True) - - def get_activation_maps(self, output, class_idx, normalized=True): - """Recreate class activation maps - - Args: - output (torch.Tensor[N, K]): output of the hooked model - class_idx (int): class index for expected activation map - normalized (bool, optional): should the activation map be normalized - - Returns: - torch.Tensor[N, H, W]: activation maps of the last forwarded batch at the hooked layer - """ - - # Retrieve the activation and gradients of the target layer - self._backprop(output, class_idx) - - # Global average pool the gradients over spatial dimensions - weights = self.hook_g.data.mean(axis=(2, 3)) - - # Assemble the CAM - return self._compute_gradcams(weights, normalized) - - -class GradCAMpp(GradCAM): - """Implements a class activation map extractor as described in https://arxiv.org/pdf/1710.11063.pdf - - Args: - model (torch.nn.Module): input model - conv_layer (str): name of the last convolutional layer - """ - - hook_a, hook_g = None, None - - def __init__(self, model, conv_layer): - - super().__init__(model, conv_layer) - - def get_activation_maps(self, output, class_idx, normalized=True): - """Recreate class activation maps - - Args: - output (torch.Tensor[N, K]): output of the hooked model - class_idx (int): class index for expected activation map - normalized (bool, optional): should the activation map be normalized - - Returns: - torch.Tensor[N, H, W]: activation maps of the last forwarded batch at the hooked layer - """ - - # Retrieve the activation and gradients of the target layer - self._backprop(output, class_idx) - - # Alpha coefficient for each pixel - grad_2 = self.hook_g.data.pow(2) - grad_3 = self.hook_g.data.pow(3) - alpha = grad_2 / (2 * grad_2 + (grad_3 * self.hook_a.data).sum(axis=(2, 3), keepdims=True)) - # Apply pixel coefficient in each weight - weights = alpha.mul(torch.relu(self.hook_g.data)).sum(axis=(2, 3)) - - # Assemble the CAM - return self._compute_gradcams(weights, normalized)