Skip to content

Commit

Permalink
refactor: Improved target resolution (#174)
Browse files Browse the repository at this point in the history
* feat: Improved target resolution

* test: Added dedicated unittest
  • Loading branch information
frgfm authored Aug 1, 2022
1 parent edda2b0 commit 7e1a328
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
6 changes: 5 additions & 1 deletion tests/test_methods_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torchvision.models import resnet18
from torchvision.models import resnet18, mobilenet_v3_large

from torchcam.methods import _utils

Expand All @@ -8,6 +8,10 @@ def test_locate_candidate_layer(mock_img_model):
mod = resnet18().eval()
assert _utils.locate_candidate_layer(mod) == "layer4"

# Mobilenet V3 Large
mod = mobilenet_v3_large().eval()
assert _utils.locate_candidate_layer(mod) == "features"

# Custom model
mod = mock_img_model.train()

Expand Down
7 changes: 3 additions & 4 deletions torchcam/methods/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ def _record_output_shape(module: nn.Module, input: Tensor, output: Tensor, name:

# Check output shapes
candidate_layer = None
for layer_name, output_shape in output_shapes:
for layer_name, output_shape in output_shapes[::-1]:
# Stop before flattening or global pooling
if len(output_shape) != (len(input_shape) + 1) or all(v == 1 for v in output_shape[2:]):
break
else:
if len(output_shape) == (len(input_shape) + 1) and any(v != 1 for v in output_shape[2:]):
candidate_layer = layer_name
break

return candidate_layer

Expand Down

0 comments on commit 7e1a328

Please sign in to comment.