feat: Added ScoreCAM implementation (#5)
* feat: Added implementation of ScoreCAM

* test: Added unittest for ScoreCAM

* docs: Updated README credits

* docs: Updated sphinx documentation

* feat: Added ScoreCAM to example script

* docs: Updated example visualization

* fix: Fixed ScoreCAM input hook

The input hook was still enabled when processing the masked chunks, which yielded an incorrect output map (see the hook-toggling sketch after this list).

* refactor: Added hook dynamic disabling option

* docs: Updated example script visualization

* docs: Updated README

* refactor: Refactored SmoothGradCAMpp signature

* test: Reflected changes on SmoothGradCAMpp signature

* docs: Fixed docstring paper reference of ScoreCAM

* fix: Reflected changes of SmoothGradCAMpp signature

* docs: Added examples in docstrings
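For context on the hook fix and the dynamic-disabling refactor above: each hook checks a flag before recording, so the extra forward passes ScoreCAM runs on masked inputs cannot overwrite the activation and input captured on the original image. A minimal sketch of that pattern, assuming a generic extractor class (the names `HookedExtractor`, `forward_without_recording`, and `target_layer` are illustrative, not torchcam's actual API):

```python
import torch
import torch.nn as nn


class HookedExtractor:
    """Records the output of one layer, with a switch to pause recording."""

    def __init__(self, model: nn.Module, target_layer: str):
        self.model = model
        self.activation = None
        self._hooks_enabled = True  # hooks only record while this flag is True
        model._modules.get(target_layer).register_forward_hook(self._hook)

    def _hook(self, module, inp, out):
        # Skip recording during auxiliary passes (e.g. on masked inputs)
        if self._hooks_enabled:
            self.activation = out.detach()

    def forward_without_recording(self, x: torch.Tensor) -> torch.Tensor:
        # Temporarily disable the hook so this pass does not clobber
        # the activation captured on the original image.
        self._hooks_enabled = False
        with torch.no_grad():
            out = self.model(x)
        self._hooks_enabled = True
        return out
```

The `_hooks_enabled` flag added to `torchcam/cams/cam.py` in the diff below follows this same pattern.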
frgfm authored Mar 24, 2020
1 parent 9021aec commit d41c073
Showing 7 changed files with 162 additions and 19 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -65,9 +65,9 @@ python scripts/cam_example.py --model resnet50 --class-idx 232
The project is currently under development; here are the objectives for the next releases:

- [x] Parallel CAMs: enable batch processing.
-- [ ] Benchmark: compare class activation map computations for different architectures.
-- [ ] Signature improvement: retrieve automatically the last convolutional layer.
-- [ ] Refine RPN: create a region proposal network using CAM.
+- [x] Benchmark: compare class activation map computations for different architectures.
+- [ ] Signature improvement: retrieve automatically the specific required layer names.
+- [ ] Refined RPN: create a region proposal network using CAM.
- [ ] Task transfer: turn a well-trained classifier into an object detector.


@@ -92,6 +92,7 @@ This project is developed and maintained by the repo owner, but the implementati
- [Grad-CAM](https://arxiv.org/abs/1610.02391): GradCAM paper, generalizing CAM to models without global average pooling.
- [Grad-CAM++](https://arxiv.org/abs/1710.11063): improvement of GradCAM for more accurate pixel-level contribution to the activation.
- [Smooth Grad-CAM++](https://arxiv.org/abs/1908.01224): SmoothGrad mechanism coupled with GradCAM.
- [Score-CAM](https://arxiv.org/abs/1910.01279): score-weighting of class activation for better interpretability.



2 changes: 2 additions & 0 deletions docs/source/cams.rst
@@ -12,6 +12,8 @@ Related to activation-based class activation maps.

.. autoclass:: CAM

.. autoclass:: ScoreCAM


Grad-CAM
--------
9 changes: 5 additions & 4 deletions scripts/cam_example.py
@@ -15,8 +15,7 @@
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_tensor, to_pil_image

-from torchcam.cams.gradcam import _GradCAM
-from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp
+from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM
from torchcam.utils import overlay_mask

VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features')
@@ -58,7 +57,9 @@ def main(args):

# Hook the corresponding layer in the model
cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
-                      GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer)]
+                      GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
+                      ScoreCAM(model, conv_layer, input_layer)]

fig, axes = plt.subplots(1, len(cam_extractors))
for idx, extractor in enumerate(cam_extractors):
model.zero_grad()
@@ -68,7 +69,7 @@
class_idx = out.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

# Use the hooked data to compute activation map
-        if isinstance(extractor, _GradCAM):
+        if isinstance(extractor, (GradCAM, GradCAMpp)):
activation_map = extractor(out, class_idx)[0].cpu().numpy()
else:
activation_map = extractor(class_idx)[0].cpu().numpy()
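As the hunk above shows, the example script now distinguishes the two calling conventions: gradient-based extractors need the model output so they can backpropagate from the target logit, while CAM, ScoreCAM, and the refactored SmoothGradCAMpp only need a class index once the forward pass has populated their hooks. A hedged sketch of that dispatch as a standalone helper (the function name `compute_map` is illustrative, not part of the script):

```python
from torchcam.cams import GradCAM, GradCAMpp


def compute_map(extractor, model, img_tensor, class_idx=None):
    """Run one forward pass and return the activation map as a numpy array."""
    model.zero_grad()
    out = model(img_tensor.unsqueeze(0))
    if class_idx is None:
        class_idx = out.squeeze(0).argmax().item()
    if isinstance(extractor, (GradCAM, GradCAMpp)):
        # Gradient-based: the extractor backpropagates from the selected logit
        activation_map = extractor(out, class_idx)[0]
    else:
        # Activation-based (CAM, ScoreCAM) and SmoothGradCAMpp: the hooks
        # already hold what is needed, so only the class index is passed
        activation_map = extractor(class_idx)[0]
    return activation_map.cpu().numpy()
```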
Binary file modified static/images/cam_example.png
14 changes: 11 additions & 3 deletions test/test_cams.py
@@ -30,14 +30,15 @@ def _get_img_tensor():

return img_tensor

-    def test_cam(self):
+    def _test_cam(self, name):
# Get a pretrained model
model = resnet18(pretrained=True).eval()
conv_layer = 'layer4'
input_layer = 'conv1'
fc_layer = 'fc'

# Hook the corresponding layer in the model
-        extractor = cams.CAM(model, conv_layer, fc_layer)
+        extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer)

# Get a dog image
img_tensor = self._get_img_tensor()
@@ -83,7 +84,14 @@ def test_smooth_gradcampp(self):
out = model(img_tensor.unsqueeze(0))

# Use the hooked data to compute activation map
-        self._verify_cam(extractor(out, out[0].argmax().item()))
+        self._verify_cam(extractor(out[0].argmax().item()))


for cam_extractor in ['CAM', 'ScoreCAM']:
def do_test(self, cam_extractor=cam_extractor):
self._test_cam(cam_extractor)

setattr(Tester, "test_" + cam_extractor.lower(), do_test)


for cam_extractor in ['GradCAM', 'GradCAMpp']:
95 changes: 93 additions & 2 deletions torchcam/cams/cam.py
@@ -5,9 +5,11 @@
GradCAM
"""

import math
import torch
import torch.nn.functional as F

-__all__ = ['CAM']
+__all__ = ['CAM', 'ScoreCAM']


class _CAM(object):
@@ -28,9 +30,12 @@ def __init__(self, model, conv_layer):
self.model = model
# Forward hook
self.hook_handles.append(self.model._modules.get(conv_layer).register_forward_hook(self._hook_a))
# Enable hooks
self._hooks_enabled = True

def _hook_a(self, module, input, output):
-        self.hook_a = output.data
+        if self._hooks_enabled:
+            self.hook_a = output.data

def clear_hooks(self):
"""Clear model hooks"""
@@ -69,6 +74,14 @@ def __repr__(self):
class CAM(_CAM):
"""Implements a class activation map extractor as described in https://arxiv.org/abs/1512.04150
Example::
>>> from torchvision.models import resnet18
>>> from torchcam.cams import CAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = CAM(model, 'layer4', 'fc')
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
@@ -87,3 +100,81 @@ def _get_weights(self, class_idx):

# Take the FC weights of the target class
return self._fc_weights[class_idx, :]


class ScoreCAM(_CAM):
"""Implements a class activation map extractor as described in https://arxiv.org/abs/1910.01279
Example::
>>> from torchvision.models import resnet18
>>> from torchcam.cams import ScoreCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = ScoreCAM(model, 'layer4', 'conv1')
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
"""

hook_a = None
hook_handles = []

def __init__(self, model, conv_layer, input_layer, max_batch=32):

super().__init__(model, conv_layer)

# Input hook
self.hook_handles.append(self.model._modules.get(input_layer).register_forward_pre_hook(self._store_input))
self.max_batch = max_batch

def _store_input(self, module, input):

if self._hooks_enabled:
self._input = input[0].data.clone()

def _get_weights(self, class_idx):

# Upsample activation to input_size
# 1 * O * M * N
upsampled_a = F.interpolate(self.hook_a, self._input.shape[-2:], mode='bilinear', align_corners=False)

# Normalize it
upsampled_a = self._normalize(upsampled_a)

# Use it as a mask
# O * I * H * W
masked_input = upsampled_a.squeeze(0).unsqueeze(1) * self._input

# Initialize weights
weights = torch.zeros(masked_input.shape[0], dtype=masked_input.dtype).to(device=masked_input.device)

# Disable hook updates
self._hooks_enabled = False
# Process by chunk (GPU RAM limitation)
for idx in range(math.ceil(weights.shape[0] / self.max_batch)):

selection_slice = slice(idx * self.max_batch, min((idx + 1) * self.max_batch, weights.shape[0]))
with torch.no_grad():
# Get the softmax probabilities of the target class
weights[selection_slice] = F.softmax(self.model(masked_input[selection_slice]), dim=1)[:, class_idx]

# Reenable hook updates
self._hooks_enabled = True

return weights

def __call__(self, class_idx, normalized=True):

# Get map weight
weights = self._get_weights(class_idx)

# Perform the weighted combination to get the CAM
batch_cams = torch.relu((weights.view(-1, 1, 1) * self.hook_a).sum(dim=1))

# Normalize the CAM
if normalized:
batch_cams = self._normalize(batch_cams)

return batch_cams
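To summarize what `_get_weights` above computes: each activation channel is upsampled to the input resolution, min-max normalized into a [0, 1] mask, multiplied with the stored input, and the masked copies are scored in chunks; the softmax probability of the target class becomes that channel's weight, and the CAM is the ReLU of the weighted sum of channels. A simplified, self-contained sketch of that step (the function name and the epsilon guard in the normalization are illustrative; unlike the class above, it assumes the model carries no recording hooks, so nothing needs to be paused during the extra forward passes):

```python
import math

import torch
import torch.nn.functional as F


def score_cam_weights(model, activations, input_tensor, class_idx, max_batch=32):
    """Sketch of the ScoreCAM weighting step.

    activations: feature maps of shape (1, C, H, W) captured on `input_tensor`
    input_tensor: the preprocessed image of shape (1, 3, H_in, W_in)
    Returns one weight per channel: the softmax score of `class_idx` on the
    input masked by that (upsampled, normalized) channel.
    """
    # Upsample each channel to the input spatial size
    upsampled = F.interpolate(activations, input_tensor.shape[-2:],
                              mode='bilinear', align_corners=False)
    # Min-max normalize each channel so it acts as a [0, 1] mask
    flat = upsampled.flatten(start_dim=2)
    mins = flat.min(dim=2).values[..., None, None]
    maxs = flat.max(dim=2).values[..., None, None]
    masks = (upsampled - mins) / (maxs - mins + 1e-8)
    # One masked copy of the input per channel: (C, 3, H_in, W_in)
    masked_input = masks.squeeze(0).unsqueeze(1) * input_tensor

    weights = torch.zeros(masked_input.shape[0], dtype=masked_input.dtype,
                          device=masked_input.device)
    # Score the masked inputs in chunks to bound memory usage
    for idx in range(math.ceil(masked_input.shape[0] / max_batch)):
        sl = slice(idx * max_batch, min((idx + 1) * max_batch, masked_input.shape[0]))
        with torch.no_grad():
            weights[sl] = F.softmax(model(masked_input[sl]), dim=1)[:, class_idx]
    return weights


# The map itself is then the ReLU of the weighted channel sum:
# cam = torch.relu((weights.view(-1, 1, 1) * activations.squeeze(0)).sum(dim=0))
```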
54 changes: 47 additions & 7 deletions torchcam/cams/gradcam.py
@@ -30,7 +30,8 @@ def __init__(self, model, conv_layer):
self.hook_handles.append(self.model._modules.get(conv_layer).register_backward_hook(self._hook_g))

def _hook_g(self, module, input, output):
-        self.hook_g = output[0].data
+        if self._hooks_enabled:
+            self.hook_g = output[0].data

def _backprop(self, output, class_idx):

@@ -64,6 +65,14 @@ def __call__(self, output, class_idx, normalized=True):
class GradCAM(_GradCAM):
"""Implements a class activation map extractor as described in https://arxiv.org/pdf/1710.11063.pdf
Example::
>>> from torchvision.models import resnet18
>>> from torchcam.cams import GradCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = GradCAM(model, 'layer4')
>>> out = model(input_tensor)
>>> cam(out, class_idx=100)
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
@@ -86,6 +95,14 @@ def _get_weights(self, output, class_idx):
class GradCAMpp(_GradCAM):
"""Implements a class activation map extractor as described in https://arxiv.org/pdf/1710.11063.pdf
Example::
>>> from torchvision.models import resnet18
>>> from torchcam.cams import GradCAMpp
>>> model = resnet18(pretrained=True).eval()
>>> cam = GradCAMpp(model, 'layer4')
>>> out = model(input_tensor)
>>> cam(out, class_idx=100)
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
@@ -114,6 +131,14 @@ class SmoothGradCAMpp(_GradCAM):
"""Implements a class activation map extractor as described in https://arxiv.org/pdf/1908.01224.pdf
with a personal correction to the paper (alpha coefficient numerator)
Example::
>>> from torchvision.models import resnet18
>>> from torchcam.cams import SmoothGradCAMpp
>>> model = resnet18(pretrained=True).eval()
>>> cam = SmoothGradCAMpp(model, 'layer4', 'conv1')
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
@@ -132,19 +157,20 @@ def __init__(self, model, conv_layer, first_layer, num_samples=4, std=0.3):
self.num_samples = num_samples
self.std = std
self._distrib = torch.distributions.normal.Normal(0, self.std)
-        self._observing = True
+        # Specific input hook updater
+        self._ihook_enabled = True

def _store_input(self, module, input):

-        if self._observing:
+        if self._ihook_enabled:
self._input = input[0].data.clone()

-    def _get_weights(self, output, class_idx):
+    def _get_weights(self, class_idx):

# Disable input update
-        self._observing = False
+        self._ihook_enabled = False
# Keep initial activation
-        init_fmap = self.hook_a.data
+        init_fmap = self.hook_a.data.clone()
# Initialize our gradient estimates
grad_2, grad_3 = torch.zeros_like(self.hook_a.data), torch.zeros_like(self.hook_a.data)
# Perform the operations N times
@@ -161,7 +187,7 @@ def _get_weights(self, output, class_idx):
grad_3.add_(self.hook_g.data.pow(3))

# Reenable input update
-        self._observing = True
+        self._ihook_enabled = True

# Average the gradient estimates
grad_2.div_(self.num_samples)
@@ -173,5 +199,19 @@ def _get_weights(self, output, class_idx):
# Apply pixel coefficient in each weight
return alpha.mul(torch.relu(self.hook_g.data)).sum(axis=(2, 3))

def __call__(self, class_idx, normalized=True):

# Get map weight
weights = self._get_weights(class_idx)

# Perform the weighted combination to get the CAM
batch_cams = torch.relu((weights.view(*weights.shape, 1, 1) * self.hook_a).sum(dim=1))

# Normalize the CAM
if normalized:
batch_cams = self._normalize(batch_cams)

return batch_cams

def __repr__(self):
return f"{self.__class__.__name__}(num_samples={self.num_samples}, std={self.std})"
