Skip to content

Commit

Permalink
feat: Added possibility to retrieve multiple CAMs in demo (#105)
Browse files Browse the repository at this point in the history
* feat: Added possibility to retrieve multiple CAMs in demo

* style: Fixed typing

* fix: Harmonized arg type support for CAM

* test: Updated unittest

* test: Fixed typo

* test: Fixed typo
  • Loading branch information
frgfm authored Oct 31, 2021
1 parent 0127071 commit 8abb3ea
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 13 deletions.
6 changes: 4 additions & 2 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def main():
if cam_method is not None:
cam_extractor = cams.__dict__[cam_method](
model,
target_layer=target_layer if len(target_layer) > 0 else None
target_layer=target_layer.split("+") if len(target_layer) > 0 else None
)

class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
Expand Down Expand Up @@ -91,7 +91,9 @@ def main():
else:
class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
# Retrieve the CAM
activation_map = cam_extractor(class_idx, out)[0]
act_maps = cam_extractor(class_idx, out)
# Fuse the CAMs if there are several
activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
# Plot the raw heatmap
fig, ax = plt.subplots()
ax.imshow(activation_map.numpy())
Expand Down
4 changes: 2 additions & 2 deletions test/test_cams_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
def test_base_cam_constructor(mock_img_model):
model = mobilenet_v2(pretrained=False).eval()
# Check that multiple target layers is disabled for base CAM
with pytest.raises(TypeError):
_ = cam.CAM(model, ['classifier.1'])
with pytest.raises(ValueError):
_ = cam.CAM(model, ['classifier.1', 'classifier.2'])

# FC layer checks
with pytest.raises(TypeError):
Expand Down
12 changes: 6 additions & 6 deletions torchcam/cams/cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ class CAM(_CAM):
def __init__(
self,
model: nn.Module,
target_layer: Optional[Union[nn.Module, str]] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
fc_layer: Optional[Union[nn.Module, str]] = None,
input_shape: Tuple[int, ...] = (3, 224, 224),
**kwargs: Any,
) -> None:

if isinstance(target_layer, list):
raise TypeError("invalid argument type for `target_layer`")
if isinstance(target_layer, list) and len(target_layer) > 1:
raise ValueError("base CAM does not support multiple target layers")

super().__init__(model, target_layer, input_shape, **kwargs)

Expand Down Expand Up @@ -132,7 +132,7 @@ class ScoreCAM(_CAM):
def __init__(
self,
model: nn.Module,
target_layer: Optional[str] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
batch_size: int = 32,
input_shape: Tuple[int, ...] = (3, 224, 224),
**kwargs: Any,
Expand Down Expand Up @@ -257,7 +257,7 @@ class SSCAM(ScoreCAM):
def __init__(
self,
model: nn.Module,
target_layer: Optional[str] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
batch_size: int = 32,
num_samples: int = 35,
std: float = 2.0,
Expand Down Expand Up @@ -346,7 +346,7 @@ class ISCAM(ScoreCAM):
def __init__(
self,
model: nn.Module,
target_layer: Optional[str] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
batch_size: int = 32,
num_samples: int = 10,
input_shape: Tuple[int, ...] = (3, 224, 224),
Expand Down
2 changes: 1 addition & 1 deletion torchcam/cams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class _CAM:
def __init__(
self,
model: nn.Module,
target_layer: Optional[Union[nn.Module, str]] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
input_shape: Tuple[int, ...] = (3, 224, 224),
enable_hooks: bool = True,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions torchcam/cams/gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class _GradCAM(_CAM):
def __init__(
self,
model: nn.Module,
target_layer: Optional[Union[nn.Module, str]] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
input_shape: Tuple[int, ...] = (3, 224, 224),
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -231,7 +231,7 @@ class SmoothGradCAMpp(_GradCAM):
def __init__(
self,
model: nn.Module,
target_layer: Optional[Union[nn.Module, str]] = None,
target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
num_samples: int = 4,
std: float = 0.3,
input_shape: Tuple[int, ...] = (3, 224, 224),
Expand Down

0 comments on commit 8abb3ea

Please sign in to comment.