Skip to content

Commit

Permalink
feat: Added automatic layer name resolution (#32)
Browse files Browse the repository at this point in the history
* test: Renamed testers

* feat: Added layer resolution utils

* feat: Added layer resolution to CAM

* test: Added unittests for utils and simplified existing ones

* docs: Updated README

* refactor: Reflected changes on CAM interface

* feat: Updated visualization example script

* style: Fixed lint

* test: Fixed unittests

* feat: Forced eval switch when possible
  • Loading branch information
frgfm authored Dec 26, 2020
1 parent 16e7641 commit 71c2756
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 115 deletions.
15 changes: 0 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Simple way to leverage the class-specific activation of convolutional layers in
* [Prerequisites](#prerequisites)
* [Installation](#installation)
* [Usage](#usage)
* [Technical Roadmap](#technical-roadmap)
* [Documentation](#documentation)
* [Contributing](#contributing)
* [Credits](#credits)
Expand Down Expand Up @@ -58,20 +57,6 @@ python scripts/cam_example.py --model resnet50 --class-idx 232





## Technical roadmap

The project is currently under development, here are the objectives for the next releases:

- [x] Parallel CAMs: enable batch processing.
- [x] Benchmark: compare class activation map computations for different architectures.
- [ ] Signature improvement: retrieve automatically the specific required layer names.
- [ ] Refined RPN: create a region proposal network using CAM.
- [ ] Task transfer: turn a well-trained classifier into an object detector.



## Documentation

The full package documentation is available [here](https://frgfm.github.io/torch-cam/) for detailed specifications. The documentation was built with [Sphinx](sphinx-doc.org) using a [theme](github.com/readthedocs/sphinx_rtd_theme) provided by [Read the Docs](readthedocs.org).
Expand Down
32 changes: 19 additions & 13 deletions scripts/cam_example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!usr/bin/python
# -*- coding: utf-8 -*-

"""
CAM visualization
"""

import math
import argparse
from io import BytesIO

Expand All @@ -18,18 +18,18 @@
from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM, ISCAM
from torchcam.utils import overlay_mask

VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features')
VGG_CONFIG = {_vgg: dict(conv_layer='features')
for _vgg in models.vgg.__dict__.keys()}

RESNET_CONFIG = {_resnet: dict(input_layer='conv1', conv_layer='layer4', fc_layer='fc')
RESNET_CONFIG = {_resnet: dict(conv_layer='layer4', fc_layer='fc')
for _resnet in models.resnet.__dict__.keys()}

DENSENET_CONFIG = {_densenet: dict(input_layer='features', conv_layer='features', fc_layer='classifier')
DENSENET_CONFIG = {_densenet: dict(conv_layer='features', fc_layer='classifier')
for _densenet in models.densenet.__dict__.keys()}

MODEL_CONFIG = {
**VGG_CONFIG, **RESNET_CONFIG, **DENSENET_CONFIG,
'mobilenet_v2': dict(input_layer='features', conv_layer='features')
'mobilenet_v2': dict(conv_layer='features')
}


Expand All @@ -43,7 +43,6 @@ def main(args):
# Pretrained imagenet model
model = models.__dict__[args.model](pretrained=True).eval().to(device=device)
conv_layer = MODEL_CONFIG[args.model]['conv_layer']
input_layer = MODEL_CONFIG[args.model]['input_layer']
fc_layer = MODEL_CONFIG[args.model]['fc_layer']

# Image
Expand All @@ -57,15 +56,17 @@ def main(args):

# Hook the corresponding layer in the model
cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer),
ISCAM(model, conv_layer, input_layer)]
GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer),
ScoreCAM(model, conv_layer), SSCAM(model, conv_layer),
ISCAM(model, conv_layer)]

# Don't trigger all hooks
for extractor in cam_extractors:
extractor._hooks_enabled = False

fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
num_rows = 2
num_cols = math.ceil(len(cam_extractors) / num_rows)
_, axes = plt.subplots(num_rows, num_cols, figsize=(6, 4))
for idx, extractor in enumerate(cam_extractors):
extractor._hooks_enabled = True
model.zero_grad()
Expand All @@ -76,6 +77,7 @@ def main(args):

# Use the hooked data to compute activation map
activation_map = extractor(class_idx, scores).cpu()

# Clean data
extractor.clear_hooks()
extractor._hooks_enabled = False
Expand All @@ -85,9 +87,13 @@ def main(args):
# Plot the result
result = overlay_mask(pil_img, heatmap)

axes[idx].imshow(result)
axes[idx].axis('off')
axes[idx].set_title(extractor.__class__.__name__, size=8)
axes[idx // num_cols][idx % num_cols].imshow(result)
axes[idx // num_cols][idx % num_cols].set_title(extractor.__class__.__name__, size=8)

# Clear axes
for row in axes:
for ax in row:
ax.axis('off')

plt.tight_layout()
if args.savefig:
Expand Down
84 changes: 50 additions & 34 deletions test/test_cams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import requests
import torch
from PIL import Image
from torch import nn
from torchvision.models import mobilenet_v2, resnet18
from torchvision.transforms.functional import normalize, resize, to_tensor

Expand All @@ -20,7 +21,7 @@ def _forward(model, input_tensor):
return scores


class Tester(unittest.TestCase):
class CAMCoreTester(unittest.TestCase):
def _verify_cam(self, cam):
# Simple verifications
self.assertIsInstance(cam, torch.Tensor)
Expand Down Expand Up @@ -67,76 +68,91 @@ def _test_extractor(self, extractor, model):

def _test_cam(self, name):
# Get a pretrained model
model = resnet18(pretrained=False).eval()
conv_layer = 'layer4'
input_layer = 'conv1'
fc_layer = 'fc'

# Hook the corresponding layer in the model
extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer)

self._test_extractor(extractor, model)

def _test_cam_arbitrary_layer(self, name):

model = resnet18(pretrained=False).eval()
conv_layer = 'layer4.1.relu'
input_layer = 'conv1'
fc_layer = 'fc'

# Hook the corresponding layer in the model
extractor = cams.__dict__[name](model, conv_layer, fc_layer if name == 'CAM' else input_layer)
extractor = cams.__dict__[name](model, conv_layer)

self._test_extractor(extractor, model)
with torch.no_grad():
self._test_extractor(extractor, model)

def _test_gradcam(self, name):

# Get a pretrained model
model = mobilenet_v2(pretrained=False)
conv_layer = 'features'
conv_layer = 'features.17.conv.3'

# Hook the corresponding layer in the model
extractor = cams.__dict__[name](model, conv_layer)

self._test_extractor(extractor, model)

def _test_gradcam_arbitrary_layer(self, name):
def test_smooth_gradcampp(self):

model = mobilenet_v2(pretrained=False)
conv_layer = 'features.17.conv.3'
# Get a pretrained model
model = mobilenet_v2(pretrained=False).eval()

# Hook the corresponding layer in the model
extractor = cams.__dict__[name](model, conv_layer)
extractor = cams.SmoothGradCAMpp(model)

self._test_extractor(extractor, model)

def test_smooth_gradcampp(self):

# Get a pretrained model
model = mobilenet_v2(pretrained=False)
conv_layer = 'features'
input_layer = 'features'
class CAMUtilsTester(unittest.TestCase):

# Hook the corresponding layer in the model
extractor = cams.SmoothGradCAMpp(model, conv_layer, input_layer)
@staticmethod
def _get_custom_module():

self._test_extractor(extractor, model)
mod = nn.Sequential(
nn.Sequential(
nn.Conv2d(3, 8, 3, 1),
nn.ReLU(),
nn.Conv2d(8, 16, 3, 1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((1, 1))
),
nn.Flatten(1),
nn.Linear(16, 1)
)
return mod

def test_locate_candidate_layer(self):

# ResNet-18
mod = resnet18().eval()
self.assertEqual(cams.utils.locate_candidate_layer(mod), 'layer4')

# Custom model
mod = self._get_custom_module()

self.assertEqual(cams.utils.locate_candidate_layer(mod), '0.3')
# Check that the model is switched back to its origin mode afterwards
self.assertTrue(mod.training)

def test_locate_linear_layer(self):

# ResNet-18
mod = resnet18().eval()
self.assertEqual(cams.utils.locate_linear_layer(mod), 'fc')

# Custom model
mod = self._get_custom_module()
self.assertEqual(cams.utils.locate_linear_layer(mod), '2')


for cam_extractor in ['CAM', 'ScoreCAM', 'SSCAM', 'ISCAM']:
def do_test(self, cam_extractor=cam_extractor):
self._test_cam(cam_extractor)
self._test_cam_arbitrary_layer(cam_extractor)

setattr(Tester, "test_" + cam_extractor.lower(), do_test)
setattr(CAMCoreTester, "test_" + cam_extractor.lower(), do_test)


for cam_extractor in ['GradCAM', 'GradCAMpp']:
def do_test(self, cam_extractor=cam_extractor):
self._test_gradcam(cam_extractor)
self._test_gradcam_arbitrary_layer(cam_extractor)

setattr(Tester, "test_" + cam_extractor.lower(), do_test)
setattr(CAMCoreTester, "test_" + cam_extractor.lower(), do_test)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchcam import utils


class Tester(unittest.TestCase):
class UtilsTester(unittest.TestCase):
def test_overlay_mask(self):

img = Image.fromarray(np.zeros((4, 4, 3)).astype(np.uint8))
Expand Down
1 change: 1 addition & 0 deletions torchcam/cams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .cam import *
from .gradcam import *
from .utils import *
Loading

0 comments on commit 71c2756

Please sign in to comment.