diff --git a/tests/test_methods_utils.py b/tests/test_methods_utils.py
index 4792bf9..402477f 100644
--- a/tests/test_methods_utils.py
+++ b/tests/test_methods_utils.py
@@ -1,4 +1,4 @@
-from torchvision.models import resnet18
+from torchvision.models import resnet18, mobilenet_v3_large
 
 from torchcam.methods import _utils
 
@@ -8,6 +8,10 @@ def test_locate_candidate_layer(mock_img_model):
     mod = resnet18().eval()
     assert _utils.locate_candidate_layer(mod) == "layer4"
 
+    # Mobilenet V3 Large
+    mod = mobilenet_v3_large().eval()
+    assert _utils.locate_candidate_layer(mod) == "features"
+
     # Custom model
     mod = mock_img_model.train()
 
diff --git a/torchcam/methods/_utils.py b/torchcam/methods/_utils.py
index 6917ffb..a5697cf 100644
--- a/torchcam/methods/_utils.py
+++ b/torchcam/methods/_utils.py
@@ -51,12 +51,11 @@ def _record_output_shape(module: nn.Module, input: Tensor, output: Tensor, name:
 
     # Check output shapes
     candidate_layer = None
-    for layer_name, output_shape in output_shapes:
+    for layer_name, output_shape in output_shapes[::-1]:
         # Stop before flattening or global pooling
-        if len(output_shape) != (len(input_shape) + 1) or all(v == 1 for v in output_shape[2:]):
-            break
-        else:
+        if len(output_shape) == (len(input_shape) + 1) and any(v != 1 for v in output_shape[2:]):
             candidate_layer = layer_name
+            break
 
     return candidate_layer
 