From 521b4f9eb168f262e788b4299412cf0959f83f03 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Sun, 27 Dec 2020 02:42:03 +0100 Subject: [PATCH] docs: Fixed example docstring and unittests (#33) * docs: Fixed example of SmoothGradCAMpp * test: Fixed unittests * test: Speeded up unittests * test: Fixed base Cam unittest * test: Optimized speed for Score CAM family * test: Optimized testing speed for SSCAM and ISCAM --- test/test_cams.py | 28 +++++++++++----------------- torchcam/cams/gradcam.py | 2 +- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/test/test_cams.py b/test/test_cams.py index 6655ade7..75087a16 100644 --- a/test/test_cams.py +++ b/test/test_cams.py @@ -11,16 +11,6 @@ from torchcam import cams -def _forward(model, input_tensor): - if model.training: - scores = model(input_tensor) - else: - with torch.no_grad(): - scores = model(input_tensor) - - return scores - - class CAMCoreTester(unittest.TestCase): def _verify_cam(self, cam): # Simple verifications @@ -50,11 +40,11 @@ def _test_extractor(self, extractor, model): img_tensor = self._get_img_tensor() # Check that a batch of 2 cannot be accepted - _ = _forward(model, torch.stack((img_tensor, img_tensor))) + _ = model(torch.stack((img_tensor, img_tensor))) self.assertRaises(ValueError, extractor, 0) # Correct forward - scores = _forward(model, img_tensor.unsqueeze(0)) + scores = model(img_tensor.unsqueeze(0)) # Check incorrect class index self.assertRaises(ValueError, extractor, -1) @@ -68,11 +58,15 @@ def _test_extractor(self, extractor, model): def _test_cam(self, name): # Get a pretrained model - model = resnet18(pretrained=False).eval() - conv_layer = 'layer4.1.relu' + model = mobilenet_v2(pretrained=False).eval() + conv_layer = None if name == "CAM" else 'features.16.conv.3' + kwargs = {} + # Speed up testing by reducing the number of samples + if name in ['SSCAM', 'ISCAM']: + kwargs['num_samples'] = 4 # Hook the corresponding layer in the model - extractor = cams.__dict__[name](model, conv_layer) + extractor = cams.__dict__[name](model, conv_layer, **kwargs) with torch.no_grad(): self._test_extractor(extractor, model) @@ -80,8 +74,8 @@ def _test_cam(self, name): def _test_gradcam(self, name): # Get a pretrained model - model = mobilenet_v2(pretrained=False) - conv_layer = 'features.17.conv.3' + model = mobilenet_v2(pretrained=False).eval() + conv_layer = 'features.18.0' # Hook the corresponding layer in the model extractor = cams.__dict__[name](model, conv_layer) diff --git a/torchcam/cams/gradcam.py b/torchcam/cams/gradcam.py index 7bf44178..0a7ec81e 100644 --- a/torchcam/cams/gradcam.py +++ b/torchcam/cams/gradcam.py @@ -195,7 +195,7 @@ class SmoothGradCAMpp(_GradCAM): >>> from torchvision.models import resnet18 >>> from torchcam.cams import SmoothGradCAMpp >>> model = resnet18(pretrained=True).eval() - >>> cam = SmoothGradCAMpp(model, 'layer4', 'conv1') + >>> cam = SmoothGradCAMpp(model, 'layer4') >>> scores = model(input_tensor) >>> cam(class_idx=100)