diff --git a/demo/app.py b/demo/app.py index 77417ef1..e38a2297 100644 --- a/demo/app.py +++ b/demo/app.py @@ -63,7 +63,7 @@ def main(): if cam_method is not None: cam_extractor = cams.__dict__[cam_method]( model, - target_layer=target_layer if len(target_layer) > 0 else None + target_layer=target_layer.split("+") if len(target_layer) > 0 else None ) class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)] @@ -91,7 +91,9 @@ def main(): else: class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1]) # Retrieve the CAM - activation_map = cam_extractor(class_idx, out)[0] + act_maps = cam_extractor(class_idx, out) + # Fuse the CAMs if there are several + activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps) # Plot the raw heatmap fig, ax = plt.subplots() ax.imshow(activation_map.numpy()) diff --git a/test/test_cams_cam.py b/test/test_cams_cam.py index 4ecf7c48..b0a66a8f 100644 --- a/test/test_cams_cam.py +++ b/test/test_cams_cam.py @@ -13,8 +13,8 @@ def test_base_cam_constructor(mock_img_model): model = mobilenet_v2(pretrained=False).eval() # Check that multiple target layers is disabled for base CAM - with pytest.raises(TypeError): - _ = cam.CAM(model, ['classifier.1']) + with pytest.raises(ValueError): + _ = cam.CAM(model, ['classifier.1', 'classifier.2']) # FC layer checks with pytest.raises(TypeError): diff --git a/torchcam/cams/cam.py b/torchcam/cams/cam.py index 1ec228c4..ddb5b758 100644 --- a/torchcam/cams/cam.py +++ b/torchcam/cams/cam.py @@ -50,14 +50,14 @@ class CAM(_CAM): def __init__( self, model: nn.Module, - target_layer: Optional[Union[nn.Module, str]] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, fc_layer: Optional[Union[nn.Module, str]] = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any, ) -> None: - if isinstance(target_layer, list): - raise TypeError("invalid argument type for `target_layer`") + if isinstance(target_layer, list) and len(target_layer) > 1: + raise ValueError("base CAM does not support multiple target layers") super().__init__(model, target_layer, input_shape, **kwargs) @@ -132,7 +132,7 @@ class ScoreCAM(_CAM): def __init__( self, model: nn.Module, - target_layer: Optional[str] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, batch_size: int = 32, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any, @@ -257,7 +257,7 @@ class SSCAM(ScoreCAM): def __init__( self, model: nn.Module, - target_layer: Optional[str] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, batch_size: int = 32, num_samples: int = 35, std: float = 2.0, @@ -346,7 +346,7 @@ class ISCAM(ScoreCAM): def __init__( self, model: nn.Module, - target_layer: Optional[str] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, batch_size: int = 32, num_samples: int = 10, input_shape: Tuple[int, ...] = (3, 224, 224), diff --git a/torchcam/cams/core.py b/torchcam/cams/core.py index 896c1049..8e4afc77 100644 --- a/torchcam/cams/core.py +++ b/torchcam/cams/core.py @@ -29,7 +29,7 @@ class _CAM: def __init__( self, model: nn.Module, - target_layer: Optional[Union[nn.Module, str]] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, input_shape: Tuple[int, ...] = (3, 224, 224), enable_hooks: bool = True, ) -> None: diff --git a/torchcam/cams/gradcam.py b/torchcam/cams/gradcam.py index 22db488e..e8abd493 100644 --- a/torchcam/cams/gradcam.py +++ b/torchcam/cams/gradcam.py @@ -27,7 +27,7 @@ class _GradCAM(_CAM): def __init__( self, model: nn.Module, - target_layer: Optional[Union[nn.Module, str]] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any, ) -> None: @@ -231,7 +231,7 @@ class SmoothGradCAMpp(_GradCAM): def __init__( self, model: nn.Module, - target_layer: Optional[Union[nn.Module, str]] = None, + target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None, num_samples: int = 4, std: float = 0.3, input_shape: Tuple[int, ...] = (3, 224, 224),