From 72d9e33640580f3d15466973b46d687086919513 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Thu, 19 Oct 2023 16:49:17 +0200 Subject: [PATCH] feat: Removes torchvision warnings --- tests/test_methods_activation.py | 4 ++-- tests/test_methods_gradient.py | 4 ++-- tests/test_metrics.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_methods_activation.py b/tests/test_methods_activation.py index 56841da..76f7f7e 100644 --- a/tests/test_methods_activation.py +++ b/tests/test_methods_activation.py @@ -6,7 +6,7 @@ def test_base_cam_constructor(mock_img_model): - model = mobilenet_v2(pretrained=False).eval() + model = mobilenet_v2(weights=None).eval() for p in model.parameters(): p.requires_grad_(False) # Check that multiple target layers is disabled for base CAM @@ -39,7 +39,7 @@ def _verify_cam(activation_map, output_size): ], ) def test_img_cams(cam_name, target_layer, fc_layer, num_samples, output_size, batch_size, mock_img_tensor): - model = mobilenet_v2(pretrained=False).eval() + model = mobilenet_v2(weights=None).eval() for p in model.parameters(): p.requires_grad_(False) kwargs = {} diff --git a/tests/test_methods_gradient.py b/tests/test_methods_gradient.py index 0def8eb..aa54a2c 100644 --- a/tests/test_methods_gradient.py +++ b/tests/test_methods_gradient.py @@ -26,7 +26,7 @@ def _verify_cam(activation_map, output_size): ], ) def test_img_cams(cam_name, target_layer, output_size, batch_size, mock_img_tensor): - model = mobilenet_v2(pretrained=False).eval() + model = mobilenet_v2(weights=None).eval() for p in model.parameters(): p.requires_grad_(False) @@ -79,7 +79,7 @@ def test_video_cams(cam_name, target_layer, output_size, mock_video_model, mock_ def test_smoothgradcampp_repr(): - model = mobilenet_v2(pretrained=False).eval() + model = mobilenet_v2(weights=None).eval() # Hook the corresponding layer in the model with gradient.SmoothGradCAMpp(model, "features.18.0") as extractor: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 8c5230d..d3e2f30 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -8,7 +8,7 @@ def test_classification_metric(): - model = mobilenet_v3_small(pretrained=False) + model = mobilenet_v3_small(weights=None) with LayerCAM(model, "features.12") as extractor: metric = metrics.ClassificationMetric(extractor, partial(torch.softmax, dim=-1))