Skip to content

Commit

Permalink
refactor: Code refactor (#6)
Browse files Browse the repository at this point in the history
* docs: Fixed header

* docs: Fixed docstrings

* refactor: Added argument check and optimized tensor access

* refactor: Reflected changes on __call__ method

* test: Fixed and refactored unittests

* docs: Added docstrings

* style: Fixed lint
  • Loading branch information
frgfm authored Mar 26, 2020
1 parent d41c073 commit 05f02bd
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 108 deletions.
9 changes: 3 additions & 6 deletions scripts/cam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,13 @@ def main(args):
fig, axes = plt.subplots(1, len(cam_extractors))
for idx, extractor in enumerate(cam_extractors):
model.zero_grad()
out = model(img_tensor.unsqueeze(0))
scores = model(img_tensor.unsqueeze(0))

# Select the class index
class_idx = out.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx
class_idx = scores.squeeze(0).argmax().item() if args.class_idx is None else args.class_idx

# Use the hooked data to compute activation map
if isinstance(extractor, (GradCAM, GradCAMpp)):
activation_map = extractor(out, class_idx)[0].cpu().numpy()
else:
activation_map = extractor(class_idx)[0].cpu().numpy()
activation_map = extractor(class_idx, scores).cpu()
# Clean data
extractor.clear_hooks()
# Convert it to PIL image
Expand Down
68 changes: 41 additions & 27 deletions test/test_cams.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,21 @@
from torchcam import cams


def _forward(model, input_tensor):
if model.training:
scores = model(input_tensor)
else:
with torch.no_grad():
scores = model(input_tensor)

return scores


class Tester(unittest.TestCase):
def _verify_cam(self, cam):
# Simple verifications
self.assertIsInstance(cam, torch.Tensor)
self.assertEqual(cam.shape, (1, 7, 7))
self.assertEqual(cam.shape, (7, 7))

@staticmethod
def _get_img_tensor():
Expand All @@ -30,6 +40,31 @@ def _get_img_tensor():

return img_tensor

def _test_extractor(self, extractor, model):

# Check missing forward raises Error
self.assertRaises(AssertionError, extractor, 0)

# Get a dog image
img_tensor = self._get_img_tensor()

# Check that a batch of 2 cannot be accepted
_ = _forward(model, torch.stack((img_tensor, img_tensor)))
self.assertRaises(ValueError, extractor, 0)

# Correct forward
scores = _forward(model, img_tensor.unsqueeze(0))

# Check incorrect class index
self.assertRaises(ValueError, extractor, -1)

# Check missing score
if extractor._score_used:
self.assertRaises(ValueError, extractor, 0)

# Use the hooked data to compute activation map
self._verify_cam(extractor(scores[0].argmax().item(), scores))

def _test_cam(self, name):
# Get a pretrained model
model = resnet18(pretrained=True).eval()
Expand All @@ -40,14 +75,7 @@ def _test_cam(self, name):
# Hook the corresponding layer in the model
extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer)

# Get a dog image
img_tensor = self._get_img_tensor()
# Forward it
with torch.no_grad():
out = model(img_tensor.unsqueeze(0))

# Use the hooked data to compute activation map
self._verify_cam(extractor(out[0].argmax().item()))
self._test_extractor(extractor, model)

def _test_gradcam(self, name):

Expand All @@ -58,33 +86,19 @@ def _test_gradcam(self, name):
# Hook the corresponding layer in the model
extractor = cams.__dict__[name](model, conv_layer)

# Get a dog image
img_tensor = self._get_img_tensor()

# Forward an image
out = model(img_tensor.unsqueeze(0))

# Use the hooked data to compute activation map
self._verify_cam(extractor(out, out[0].argmax().item()))
self._test_extractor(extractor, model)

def test_smooth_gradcampp(self):

# Get a pretrained model
model = mobilenet_v2(pretrained=True)
conv_layer = 'features'
first_layer = 'features'
input_layer = 'features'

# Hook the corresponding layer in the model
extractor = cams.SmoothGradCAMpp(model, conv_layer, first_layer)

# Get a dog image
img_tensor = self._get_img_tensor()

# Forward an image
out = model(img_tensor.unsqueeze(0))
extractor = cams.SmoothGradCAMpp(model, conv_layer, input_layer)

# Use the hooked data to compute activation map
self._verify_cam(extractor(out[0].argmax().item()))
self._test_extractor(extractor, model)


for cam_extractor in ['CAM', 'ScoreCAM']:
Expand Down
90 changes: 66 additions & 24 deletions torchcam/cams/cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

"""
GradCAM
CAM
"""

import math
Expand Down Expand Up @@ -32,8 +32,13 @@ def __init__(self, model, conv_layer):
self.hook_handles.append(self.model._modules.get(conv_layer).register_forward_hook(self._hook_a))
# Enable hooks
self._hooks_enabled = True
# Should ReLU be used before normalization
self._relu = False
# Model output is used by the extractor
self._score_used = False

def _hook_a(self, module, input, output):
"""Activation hook"""
if self._hooks_enabled:
self.hook_a = output.data

Expand All @@ -44,22 +49,62 @@ def clear_hooks(self):

@staticmethod
def _normalize(cams):
"""CAM normalization"""
cams -= cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1)
cams /= cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1)

return cams

def _get_weights(self, class_idx):
def _get_weights(self, class_idx, scores=None):

raise NotImplementedError

def __call__(self, class_idx, normalized=True):
def _precheck(self, class_idx, scores):
"""Check for invalid computation cases"""

# Check that forward has already occurred
if self.hook_a is None:
raise AssertionError("Inputs need to be forwarded in the model for the conv features to be hooked")
# Check batch size
if self.hook_a.shape[0] != 1:
raise ValueError(f"expected a 1-sized batch to be hooked. Received: {self.hook_a.shape[0]}")

# Check class_idx value
if class_idx < 0:
raise ValueError("Incorrect `class_idx` argument value")

# Check scores arg
if self._score_used and not isinstance(scores, torch.Tensor):
raise ValueError(f"model output scores is required to be passed to compute CAMs")

def __call__(self, class_idx, scores=None, normalized=True):

# Integrity check
self._precheck(class_idx, scores)

# Compute CAM
return self.compute_cams(class_idx, scores, normalized)

def compute_cams(self, class_idx, scores=None, normalized=True):
"""Compute the CAM for a specific output class
Args:
class_idx (int): output class index of the target class whose CAM will be computed
scores (torch.Tensor[1, K], optional): forward output scores of the hooked model
normalized (bool, optional): whether the CAM should be normalized
Returns:
torch.Tensor[M, N]: class activation map of hooked conv layer
"""

# Get map weight
weights = self._get_weights(class_idx)
weights = self._get_weights(class_idx, scores)

# Perform the weighted combination to get the CAM
batch_cams = (weights.view(-1, 1, 1) * self.hook_a).sum(dim=1)
batch_cams = (weights.unsqueeze(-1).unsqueeze(-1) * self.hook_a.squeeze(0)).sum(dim=0)

if self._relu:
batch_cams = F.relu(batch_cams, inplace=True)

# Normalize the CAM
if normalized:
Expand All @@ -85,6 +130,7 @@ class CAM(_CAM):
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
fc_layer (str): name of the fully convolutional layer
"""

hook_a = None
Expand All @@ -96,7 +142,8 @@ def __init__(self, model, conv_layer, fc_layer):
# Softmax weight
self._fc_weights = self.model._modules.get(fc_layer).weight.data

def _get_weights(self, class_idx):
def _get_weights(self, class_idx, scores=None):
"""Computes the weight coefficients of the hooked activation maps"""

# Take the FC weights of the target class
return self._fc_weights[class_idx, :]
Expand All @@ -116,25 +163,31 @@ class ScoreCAM(_CAM):
Args:
model (torch.nn.Module): input model
conv_layer (str): name of the last convolutional layer
input_layer (str): name of the first layer
batch_size (int, optional): batch size used to forward masked inputs
"""

hook_a = None
hook_handles = []

def __init__(self, model, conv_layer, input_layer, max_batch=32):
def __init__(self, model, conv_layer, input_layer, batch_size=32):

super().__init__(model, conv_layer)

# Input hook
self.hook_handles.append(self.model._modules.get(input_layer).register_forward_pre_hook(self._store_input))
self.max_batch = max_batch
self.bs = batch_size
# Ensure ReLU is applied to CAM before normalization
self._relu = True

def _store_input(self, module, input):
"""Store model input tensor"""

if self._hooks_enabled:
self._input = input[0].data.clone()

def _get_weights(self, class_idx):
def _get_weights(self, class_idx, scores=None):
"""Computes the weight coefficients of the hooked activation maps"""

# Upsample activation to input_size
# 1 * O * M * N
Expand All @@ -153,9 +206,9 @@ def _get_weights(self, class_idx):
# Disable hook updates
self._hooks_enabled = False
# Process by chunk (GPU RAM limitation)
for idx in range(math.ceil(weights.shape[0] / self.max_batch)):
for idx in range(math.ceil(weights.shape[0] / self.bs)):

selection_slice = slice(idx * self.max_batch, min((idx + 1) * self.max_batch, weights.shape[0]))
selection_slice = slice(idx * self.bs, min((idx + 1) * self.bs, weights.shape[0]))
with torch.no_grad():
# Get the softmax probabilities of the target class
weights[selection_slice] = F.softmax(self.model(masked_input[selection_slice]), dim=1)[:, class_idx]
Expand All @@ -165,16 +218,5 @@ def _get_weights(self, class_idx):

return weights

def __call__(self, class_idx, normalized=True):

# Get map weight
weights = self._get_weights(class_idx)

# Perform the weighted combination to get the CAM
batch_cams = torch.relu((weights.view(-1, 1, 1) * self.hook_a).sum(dim=1))

# Normalize the CAM
if normalized:
batch_cams = self._normalize(batch_cams)

return batch_cams
def __repr__(self):
return f"{self.__class__.__name__}(batch_size={self.bs})"
Loading

0 comments on commit 05f02bd

Please sign in to comment.