From 443b01d1bb0669c9e36ce7780cac036761b4cdf9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 26 Jul 2024 04:35:40 +0000 Subject: [PATCH 1/5] build(deps): bump ruff from 0.5.2 to 0.5.5 Bumps [ruff](https://github.com/astral-sh/ruff) from 0.5.2 to 0.5.5. - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/0.5.2...0.5.5) --- updated-dependencies: - dependency-name: ruff dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 50aa246..310ebf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ test = [ "pytest-pretty>=1.0.0,<2.0.0", ] quality = [ - "ruff==0.5.2", + "ruff==0.5.5", "mypy==1.10.0", "types-Pillow", "pre-commit>=3.0.0,<4.0.0", @@ -77,7 +77,7 @@ dev = [ "pytest-cov>=4.0.0,<5.0.0", "pytest-pretty>=1.0.0,<2.0.0", # style - "ruff==0.5.2", + "ruff==0.5.5", "mypy==1.10.0", "types-Pillow", "pre-commit>=3.0.0,<4.0.0", From 0cf4bf387776de4caff8a8349c103b6d433b2103 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Mon, 29 Jul 2024 22:18:24 +0200 Subject: [PATCH 2/5] style(pre-commit): bump ruff to 0.5.5 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 473c906..fca9c0d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.5.2' + rev: 'v0.5.5' hooks: - id: ruff args: From 330f2279b8c1243465c39f2e5e6cd4bca6c0d33e Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Tue, 30 Jul 2024 13:12:08 +0200 Subject: [PATCH 3/5] style(pyproject): update ruff rule selection --- pyproject.toml | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 310ebf7..6ef93ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,37 +116,43 @@ preview = true [tool.ruff.lint] select = [ + "F", # pyflakes "E", # pycodestyle errors "W", # pycodestyle warnings - "D101", "D103", # pydocstyle missing docstring in public function/class - "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", # pydocstyle - "F", # pyflakes "I", # isort - "C4", # flake8-comprehensions - "B", # flake8-bugbear - "CPY001", # flake8-copyright - "ISC", # flake8-implicit-str-concat - "PYI", # flake8-pyi - "NPY", # numpy - "PERF", # perflint - "RUF", # ruff specific - "PTH", # flake8-use-pathlib - "S", # flake8-bandit "N", # pep8-naming - "T10", # flake8-debugger - "T20", # flake8-print - "PT", # flake8-pytest-style - "LOG", # flake8-logging - "SIM", # flake8-simplify + "D101", "D103", # pydocstyle missing docstring in public function/class + "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", # pydocstyle "YTT", # flake8-2020 "ANN", # flake8-annotations "ASYNC", # flake8-async + "S", # flake8-bandit "BLE", # flake8-blind-except + "B", # flake8-bugbear "A", # flake8-builtins + "COM", # flake8-commas + "CPY", # flake8-copyright + "C4", # flake8-comprehensions + "T10", # flake8-debugger + "ISC", # flake8-implicit-str-concat "ICN", # flake8-import-conventions + "LOG", # flake8-logging "PIE", # flake8-pie + "T20", # flake8-print + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RET", # flake8-return + "SLF", # flake8-self + "SIM", # flake8-simplify "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "PERF", # perflint + "NPY", # numpy + "FAST", # fastapi "FURB", # refurb + "RUF", # ruff specific + "N", # pep8-naming ] ignore = [ "E501", # line too long, handled by black @@ -179,7 +185,7 @@ known-third-party = ["torch", "torchvision"] "scripts/**.py" = ["D", "T201", "N812", "S101", "ANN"] ".github/**.py" = ["D", "T201", "S602", "S101", "ANN"] "docs/**.py" = ["E402", "D103", "ANN", "A001", "ARG001"] -"tests/**.py" = ["D103", "CPY001", "S101", "PT011", "ANN"] +"tests/**.py" = ["D103", "CPY001", "S101", "PT011", "ANN", "SLF001"] "demo/**.py" = ["D103", "ANN"] "setup.py" = ["T201"] From 7a1fb43301f696961f540e4bb76f8d4fab47fae9 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Tue, 30 Jul 2024 13:24:15 +0200 Subject: [PATCH 4/5] style: fix lint --- .github/collect_env.py | 9 ++++----- .github/verify_labels.py | 4 +--- scripts/cam_example.py | 4 ++-- torchcam/methods/activation.py | 7 +++---- torchcam/methods/core.py | 25 ++++++++++++++++--------- torchcam/metrics.py | 4 ++-- 6 files changed, 28 insertions(+), 25 deletions(-) diff --git a/.github/collect_env.py b/.github/collect_env.py index 26ab4b7..57b262e 100644 --- a/.github/collect_env.py +++ b/.github/collect_env.py @@ -159,14 +159,13 @@ def get_nvidia_smi(): def get_platform(): if sys.platform.startswith("linux"): return "linux" - elif sys.platform.startswith("win32"): + if sys.platform.startswith("win32"): return "win32" - elif sys.platform.startswith("cygwin"): + if sys.platform.startswith("cygwin"): return "cygwin" - elif sys.platform.startswith("darwin"): + if sys.platform.startswith("darwin"): return "darwin" - else: - return sys.platform + return sys.platform def get_mac_version(run_lambda): diff --git a/.github/verify_labels.py b/.github/verify_labels.py index b4474ef..ff3c3d3 100644 --- a/.github/verify_labels.py +++ b/.github/verify_labels.py @@ -78,9 +78,7 @@ def parse_args(): ) parser.add_argument("pr", type=int, help="PR number") - args = parser.parse_args() - - return args + return parser.parse_args() if __name__ == "__main__": diff --git a/scripts/cam_example.py b/scripts/cam_example.py index b395267..5ea8ddd 100644 --- a/scripts/cam_example.py +++ b/scripts/cam_example.py @@ -74,7 +74,7 @@ def main(args): ax.set_title("Input", size=8) for idx, extractor in zip(range(1, len(cam_extractors) + 1), cam_extractors): - extractor._hooks_enabled = True + extractor.enable_hooks() model.zero_grad() scores = model(img_tensor.unsqueeze(0)) @@ -85,8 +85,8 @@ def main(args): activation_map = extractor(class_idx, scores)[0].squeeze(0).cpu() # Clean data + extractor.disable_hooks() extractor.remove_hooks() - extractor._hooks_enabled = False # Convert it to PIL image # The indexing below means first image in batch heatmap = to_pil_image(activation_map, mode="F") diff --git a/torchcam/methods/activation.py b/torchcam/methods/activation.py index 01a9371..6373853 100644 --- a/torchcam/methods/activation.py +++ b/torchcam/methods/activation.py @@ -90,8 +90,7 @@ def _get_weights( # Take the FC weights of the target class if isinstance(class_idx, int): return [self._fc_weights[class_idx, :].unsqueeze(0)] - else: - return [self._fc_weights[class_idx, :]] + return [self._fc_weights[class_idx, :]] class ScoreCAM(_CAM): @@ -214,7 +213,7 @@ def _get_weights( ] # Disable hook updates - self._hooks_enabled = False + self.disable_hooks() # Switch to eval origin_mode = self.model.training self.model.eval() @@ -222,7 +221,7 @@ def _get_weights( weights: List[Tensor] = self._get_score_weights(upsampled_a, class_idx) # Reenable hook updates - self._hooks_enabled = True + self.enable_hooks() # Put back the model in the correct mode self.model.training = origin_mode diff --git a/torchcam/methods/core.py b/torchcam/methods/core.py index 5c1ad60..5322cf1 100644 --- a/torchcam/methods/core.py +++ b/torchcam/methods/core.py @@ -78,6 +78,14 @@ def __init__( # Model output is used by the extractor self._score_used = False + def enable_hooks(self) -> None: + """Enable hooks.""" + self._hooks_enabled = True + + def disable_hooks(self) -> None: + """Disable hooks.""" + self._hooks_enabled = False + def __enter__(self) -> "_CAM": return self @@ -236,17 +244,16 @@ def fuse_cams(cls, cams: List[Tensor], target_shape: Optional[Tuple[int, int]] = if len(cams) == 0: raise ValueError("argument `cams` cannot be an empty list") - elif len(cams) == 1: + if len(cams) == 1: return cams[0] + # Resize to the biggest CAM if no value was provided for `target_shape` + if isinstance(target_shape, tuple): + _shape = target_shape else: - # Resize to the biggest CAM if no value was provided for `target_shape` - if isinstance(target_shape, tuple): - _shape = target_shape - else: - _shape = tuple(map(max, zip(*[tuple(cam.shape[1:]) for cam in cams]))) - # Scale cams - scaled_cams = cls._scale_cams(cams) - return cls._fuse_cams(scaled_cams, _shape) + _shape = tuple(map(max, zip(*[tuple(cam.shape[1:]) for cam in cams]))) + # Scale cams + scaled_cams = cls._scale_cams(cams) + return cls._fuse_cams(scaled_cams, _shape) @staticmethod def _scale_cams(cams: List[Tensor]) -> List[Tensor]: diff --git a/torchcam/metrics.py b/torchcam/metrics.py index 8f2e6a2..2335c87 100644 --- a/torchcam/metrics.py +++ b/torchcam/metrics.py @@ -97,7 +97,7 @@ def update( cams = self.cam_extractor(preds.cpu().numpy().tolist(), probs) cam = self.cam_extractor.fuse_cams(cams) probs = probs.gather(1, preds.unsqueeze(1)).squeeze(1) - self.cam_extractor._hooks_enabled = False + self.cam_extractor.disable_hooks() # Safeguard: replace NaNs cam[torch.isnan(cam)] = 0 # Resize the CAM @@ -116,7 +116,7 @@ def update( # Increase increase = probs < masked_probs - self.cam_extractor._hooks_enabled = True + self.cam_extractor.enable_hooks() self.drop += drop.sum().item() self.increase += increase.sum().item() From 93fa0729bd49261370268ccdf190b3086f785703 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Tue, 30 Jul 2024 13:31:11 +0200 Subject: [PATCH 5/5] ci(style): update ruff installation --- .github/workflows/style.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index a0e0398..33b0424 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -22,7 +22,7 @@ jobs: - name: Run ruff run: | python -m pip install --upgrade uv - uv pip install --system ruff==0.3.0 + uv pip install --system --upgrade -e ".[quality]" ruff --version ruff check --diff . @@ -62,7 +62,7 @@ jobs: - name: Run ruff run: | python -m pip install --upgrade uv - uv pip install --system ruff==0.3.0 + uv pip install --system --upgrade -e ".[quality]" ruff --version ruff format --check --diff .