feat: Added original CAM implementation (#2)
* docs: Fixed usage instruction

* feat: Added original CAM implementation

* chore: Reorganized package

* test: Added CAM unittest

* refactor: Refactored CAM

* refactor: Refactored GradCAMs

* refactor: Refactored CAMs

* test: Updated unittests accordingly

* docs: Updated readme

* style: Removed extra blank line

* docs: Updated documentation
frgfm authored Mar 24, 2020
1 parent 29bb8f3 commit 7be0b4f
Showing 11 changed files with 305 additions and 199 deletions.
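The user-facing effect of the package reorganization is the import path: the class activation map extractors now live under torchcam.cams and the mask overlay helper under torchcam.utils, as the README diff below shows. In short:

# Old layout (before this commit)
from gradcam import GradCAM, GradCAMpp, overlay_mask

# New layout (as of this commit)
from torchcam.cams import CAM, GradCAM, GradCAMpp
from torchcam.utils import overlay_mask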
7 changes: 4 additions & 3 deletions README.md
@@ -50,7 +50,8 @@ import matplotlib.pyplot as plt
from torchvision.models import resnet50
from torchvision.transforms import transforms
from torchvision.transforms.functional import to_pil_image
-from gradcam import GradCAM, GradCAMpp, overlay_mask
+from torchcam.cams import CAM, GradCAM, GradCAMpp
+from torchcam.utils import overlay_mask


# Pretrained imagenet model
@@ -81,7 +82,7 @@ classes = {int(key):value for (key, value)
class_idx = 232

# Use the hooked data to compute activation map
-activation_maps = gradcam.get_activation_maps(out, class_idx)
+activation_maps = gradcam(out, class_idx)
# Convert it to PIL image
# The indexing below means first image in batch
heatmap = to_pil_image(activation_maps[0].cpu().numpy(), mode='F')
@@ -101,7 +102,7 @@ plt.imshow(result); plt.axis('off'); plt.title(classes.get(class_idx)); plt.tigh

The project is currently under development; here are the objectives for the next releases:

-- [ ] Parallel CAMs: enable batch processing.
+- [x] Parallel CAMs: enable batch processing.
- [ ] Benchmark: compare class activation map computations for different architectures.
- [ ] Signature improvement: retrieve automatically the last convolutional layer.
- [ ] Refine RPN: create a region proposal network using CAM.
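With the "Parallel CAMs" item now checked off, the extractors work on whatever batch was forwarded through the hooked model and return one map per input. A minimal sketch of batched extraction, assuming the package is installed and torchvision's pretrained weights are available (the random tensor is a stand-in for preprocessed images):

import torch
from torchvision.models import resnet18
from torchcam.cams import CAM

# Hook a pretrained classifier: 'layer4' holds the last conv feature maps, 'fc' the classification weights
model = resnet18(pretrained=True).eval()
extractor = CAM(model, 'layer4', 'fc')

# Forward a batch of four (stand-in) images to fill the forward hook
batch = torch.rand(4, 3, 224, 224)
with torch.no_grad():
    _ = model(batch)

# One activation map per image for the requested class (232 = border collie)
batch_cams = extractor(class_idx=232)
print(batch_cams.shape)  # expected: torch.Size([4, 7, 7])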
24 changes: 24 additions & 0 deletions docs/source/cams.rst
@@ -0,0 +1,24 @@
torchcam.cams
=============


.. currentmodule:: torchcam.cams


CAM
--------
Related to activation-based class activation maps.


.. autoclass:: CAM


Grad-CAM
--------
Related to gradient-based class activation maps.


.. autoclass:: GradCAM


.. autoclass:: GradCAMpp
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -7,7 +7,7 @@ The :mod:`torchcam` package gives PyTorch users the possibility to visualize the
   :maxdepth: 1
   :caption: Package Reference

-   torchcam
+   cams
   utils


16 changes: 0 additions & 16 deletions docs/source/torchcam.rst

This file was deleted.

82 changes: 82 additions & 0 deletions test/test_cams.py
@@ -0,0 +1,82 @@
import unittest
import requests
from io import BytesIO
from PIL import Image
import torch
from torchvision.models import resnet18, mobilenet_v2
from torchvision.transforms.functional import resize, to_tensor, normalize

from torchcam import cams


class Tester(unittest.TestCase):

    def _verify_cam(self, cam):
        # Simple verifications
        self.assertIsInstance(cam, torch.Tensor)
        self.assertEqual(cam.shape, (1, 7, 7))

    @staticmethod
    def _get_img_tensor():

        # Get a dog image
        URL = 'https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg'
        response = requests.get(URL)

        # Forward an image
        pil_img = Image.open(BytesIO(response.content), mode='r').convert('RGB')
        img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))),
                               [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        return img_tensor

    def test_cam(self):
        # Get a pretrained model
        model = resnet18(pretrained=True).eval()
        conv_layer = 'layer4'
        fc_layer = 'fc'
        # Border collie index in ImageNet
        class_idx = 232

        # Hook the corresponding layer in the model
        extractor = cams.CAM(model, conv_layer, fc_layer)

        # Get a dog image
        img_tensor = self._get_img_tensor()
        # Forward it
        with torch.no_grad():
            _ = model(img_tensor.unsqueeze(0))

        # Use the hooked data to compute activation map
        self._verify_cam(extractor(class_idx))

    def _test_gradcam(self, name):

        # Get a pretrained model
        model = mobilenet_v2(pretrained=True)
        conv_layer = 'features'
        # Border collie index in ImageNet
        class_idx = 232

        # Hook the corresponding layer in the model
        extractor = cams.__dict__[name](model, conv_layer)

        # Get a dog image
        img_tensor = self._get_img_tensor()

        # Forward an image
        out = model(img_tensor.unsqueeze(0))

        # Use the hooked data to compute activation map
        self._verify_cam(extractor(out, class_idx))


for cam_extractor in ['GradCAM', 'GradCAMpp']:
    def do_test(self, cam_extractor=cam_extractor):
        self._test_gradcam(cam_extractor)

    setattr(Tester, "test_" + cam_extractor.lower(), do_test)


if __name__ == '__main__':
    unittest.main()
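The loop after the class registers one test method per gradient-based extractor; the cam_extractor=cam_extractor default argument freezes each loop value, so the two generated tests do not both end up calling 'GradCAMpp'. Written out by hand, the registered methods would look roughly like this (illustration only, not part of the suite):

# Illustration only: the methods the setattr loop effectively adds to Tester
def test_gradcam(self):
    self._test_gradcam('GradCAM')


def test_gradcampp(self):
    self._test_gradcam('GradCAMpp')


Tester.test_gradcam = test_gradcam
Tester.test_gradcampp = test_gradcampp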
55 changes: 0 additions & 55 deletions test/test_gradcam.py

This file was deleted.

2 changes: 1 addition & 1 deletion torchcam/__init__.py
@@ -1,4 +1,4 @@
-from torchcam.gradcam import *
+from torchcam import cams
from torchcam import utils


Expand Down
5 changes: 5 additions & 0 deletions torchcam/cams/__init__.py
@@ -0,0 +1,5 @@
from .cam import *
from .gradcam import *

del cam
del gradcam
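The star imports re-export the public classes (CAM from cam.py, the Grad-CAM variants from gradcam.py), and the del statements then drop the submodule names so that torchcam.cams only exposes the extractors themselves. A quick sanity check, assuming the package is installed:

from torchcam import cams

# The extractors are re-exported at the package level...
public = [name for name in dir(cams) if not name.startswith('_')]
print(public)  # expected to include 'CAM', 'GradCAM' and 'GradCAMpp'

# ...while the submodule names were removed by the del statements in __init__.py
print('cam' in public, 'gradcam' in public)  # expected: False False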
80 changes: 80 additions & 0 deletions torchcam/cams/cam.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
CAM
"""

import torch


__all__ = ['CAM']


class _CAM(object):
    """Implements a class activation map extractor

    Args:
        model (torch.nn.Module): input model
        conv_layer (str): name of the last convolutional layer
    """

    hook_a = None

    def __init__(self, model, conv_layer):

        if not hasattr(model, conv_layer):
            raise ValueError(f"Unable to find submodule {conv_layer} in the model")
        self.model = model
        # Forward hook
        self.model._modules.get(conv_layer).register_forward_hook(self._hook_a)

    def _hook_a(self, module, input, output):
        self.hook_a = output.data

    @staticmethod
    def _normalize(cams):
        cams -= cams.flatten(start_dim=1).min().view(-1, 1, 1)
        cams /= cams.flatten(start_dim=1).max().view(-1, 1, 1)

        return cams

    def _get_weights(self, class_idx):

        raise NotImplementedError

    def __call__(self, class_idx, normalized=True):

        # Get map weight
        weights = self._get_weights(class_idx)

        # Perform the weighted combination to get the CAM
        batch_cams = (weights.view(-1, 1, 1) * self.hook_a).sum(dim=1)

        # Normalize the CAM
        if normalized:
            batch_cams = self._normalize(batch_cams)

        return batch_cams
class CAM(_CAM):
    """Implements a class activation map extractor as described in https://arxiv.org/abs/1512.04150

    Args:
        model (torch.nn.Module): input model
        conv_layer (str): name of the last convolutional layer
        fc_layer (str): name of the fully connected layer holding the classification weights
    """

    hook_a = None

    def __init__(self, model, conv_layer, fc_layer):

        super().__init__(model, conv_layer)
        # Softmax weight
        self._fc_weights = self.model._modules.get(fc_layer).weight.data

    def _get_weights(self, class_idx):

        # Take the FC weights of the target class
        return self._fc_weights[class_idx, :]
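The split between _CAM and CAM is what the gradient-based extractors in gradcam.py build on: a subclass only has to provide _get_weights, and __call__ handles the weighted combination over the hooked feature map channels (the sum over k of w_k^c * A_k from the CAM paper) plus the optional min-max normalization. As an illustration of that extension point, a toy extractor (hypothetical, not part of torchcam) that weights every channel equally:

import torch

from torchcam.cams.cam import _CAM  # private base class, imported here for illustration only


class MeanCAM(_CAM):
    """Toy extractor (not part of torchcam): uniform weight for every feature map channel."""

    def _get_weights(self, class_idx):
        # self.hook_a holds the (N, C, H, W) feature maps captured by the forward hook
        num_channels = self.hook_a.shape[1]
        return torch.ones(num_channels) / num_channels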