diff --git a/grizzly/args.py b/grizzly/args.py index 5c201fbe..ac402562 100644 --- a/grizzly/args.py +++ b/grizzly/args.py @@ -18,8 +18,7 @@ from FTB.ProgramConfiguration import ProgramConfiguration from .common.fuzzmanager import FM_CONFIG -from .common.plugins import scan as scan_plugins -from .common.plugins import scan_target_assets +from .common.plugins import scan_plugins, scan_target_assets from .common.utils import DEFAULT_TIME_LIMIT, TIMEOUT_DELAY, package_version diff --git a/grizzly/common/plugins.py b/grizzly/common/plugins.py index d2a46248..3b2e53c8 100644 --- a/grizzly/common/plugins.py +++ b/grizzly/common/plugins.py @@ -4,9 +4,13 @@ from logging import getLogger from typing import Any, Dict, List, Tuple -from pkg_resources import iter_entry_points +try: + from pkg_resources import iter_entry_points +except ImportError: + from .utils import iter_entry_points # type: ignore -__all__ = ("load", "scan", "PluginLoadError") + +__all__ = ("load_plugin", "scan_plugins", "PluginLoadError") LOG = getLogger(__name__) @@ -16,7 +20,7 @@ class PluginLoadError(Exception): """Raised if loading a plug-in fails""" -def load(name: str, group: str, base_type: type) -> Any: +def load_plugin(name: str, group: str, base_type: type) -> Any: """Load a plug-in. Args: @@ -30,8 +34,8 @@ def load(name: str, group: str, base_type: type) -> Any: assert isinstance(base_type, type) for entry in iter_entry_points(group): if entry.name == name: - LOG.debug("loading %r (%s)", name, base_type.__name__) plugin = entry.load() + LOG.debug("loading %r (%s)", name, base_type.__name__) break else: raise PluginLoadError(f"{name!r} not found in {group!r}") @@ -40,7 +44,7 @@ def load(name: str, group: str, base_type: type) -> Any: return plugin -def scan(group: str) -> List[str]: +def scan_plugins(group: str) -> List[str]: """Scan for installed plug-ins. Args: @@ -49,7 +53,7 @@ def scan(group: str) -> List[str]: Returns: Names of installed entry points. """ - found = [] + found: List[str] = [] LOG.debug("scanning %r", group) for entry in iter_entry_points(group): if entry.name in found: @@ -68,7 +72,7 @@ def scan_target_assets() -> Dict[str, Tuple[str, ...]]: Returns: Name of target and list of supported assets. """ - assets = {} + assets: Dict[str, Tuple[str, ...]] = {} for entry in iter_entry_points("grizzly_targets"): assets[entry.name] = entry.load().SUPPORTED_ASSETS return assets diff --git a/grizzly/common/test_plugins.py b/grizzly/common/test_plugins.py index 87a7ebd3..85209e3a 100644 --- a/grizzly/common/test_plugins.py +++ b/grizzly/common/test_plugins.py @@ -1,11 +1,12 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from pkg_resources import EntryPoint +from importlib.metadata import EntryPoint + from pytest import raises from ..target import Target -from .plugins import PluginLoadError, load, scan, scan_target_assets +from .plugins import PluginLoadError, load_plugin, scan_plugins, scan_target_assets class FakeType1: @@ -19,22 +20,23 @@ class FakeType2: def test_load_01(mocker): """test load() - nothing to load""" mocker.patch( - "grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[] + "grizzly.common.plugins.iter_entry_points", autospec=True, return_value=() ) with raises(PluginLoadError, match="'test-name' not found in 'test-group'"): - load("test-name", "test-group", FakeType1) + load_plugin("test-name", "test-group", FakeType1) def test_load_02(mocker): """test load() - successful load""" - # Note: Mock.name cannot be set via the constructor so spec_set cannot be used entry = mocker.Mock(spec=EntryPoint) entry.name = "test-name" entry.load.return_value = FakeType1 mocker.patch( - "grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[entry] + "grizzly.common.plugins.iter_entry_points", + autospec=True, + return_value=(entry,), ) - assert load("test-name", "test-group", FakeType1) + assert load_plugin("test-name", "test-group", FakeType1) def test_load_03(mocker): @@ -43,18 +45,20 @@ def test_load_03(mocker): entry.name = "test-name" entry.load.return_value = FakeType1 mocker.patch( - "grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[entry] + "grizzly.common.plugins.iter_entry_points", + autospec=True, + return_value=(entry,), ) with raises(PluginLoadError, match="'test-name' doesn't inherit from FakeType2"): - load("test-name", "test-group", FakeType2) + load_plugin("test-name", "test-group", FakeType2) def test_scan_01(mocker): """test scan() - no entries found""" mocker.patch( - "grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[] + "grizzly.common.plugins.iter_entry_points", autospec=True, return_value=() ) - assert not scan("test_group") + assert not scan_plugins("test_group") def test_scan_02(mocker): @@ -64,10 +68,10 @@ def test_scan_02(mocker): mocker.patch( "grizzly.common.plugins.iter_entry_points", autospec=True, - return_value=[entry, entry], + return_value=(entry, entry), ) with raises(PluginLoadError, match="Duplicate entry 'test_entry' in 'test_group'"): - scan("test_group") + scan_plugins("test_group") def test_scan_03(mocker): @@ -77,26 +81,26 @@ def test_scan_03(mocker): mocker.patch( "grizzly.common.plugins.iter_entry_points", autospec=True, - return_value=[entry], + return_value=(entry,), ) - assert "test-name" in scan("test_group") + assert "test-name" in scan_plugins("test_group") def test_scan_target_assets_01(mocker): """test scan_target_assets() - success""" targ1 = mocker.Mock(spec=EntryPoint) targ1.name = "t1" - targ1.load.return_value = mocker.Mock(spec_set=Target, SUPPORTED_ASSETS=None) + targ1.load.return_value = mocker.Mock(spec_set=Target, SUPPORTED_ASSETS=()) targ2 = mocker.Mock(spec=EntryPoint) targ2.name = "t2" targ2.load.return_value = mocker.Mock(spec_set=Target, SUPPORTED_ASSETS=("a", "B")) mocker.patch( "grizzly.common.plugins.iter_entry_points", autospec=True, - return_value=[targ1, targ2], + return_value=(targ1, targ2), ) assets = scan_target_assets() assert "t1" in assets - assert assets["t1"] is None + assert not assets["t1"] assert "t2" in assets assert "B" in assets["t2"] diff --git a/grizzly/common/utils.py b/grizzly/common/utils.py index 5b1dd41b..c6a638d1 100644 --- a/grizzly/common/utils.py +++ b/grizzly/common/utils.py @@ -2,12 +2,13 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from enum import IntEnum, unique -from importlib.metadata import PackageNotFoundError, version +from importlib.metadata import EntryPoint, PackageNotFoundError, entry_points, version from logging import DEBUG, basicConfig, getLogger from os import getenv from pathlib import Path +from sys import version_info from tempfile import gettempdir -from typing import Any, Iterable, Optional, Tuple, Union +from typing import Any, Generator, Iterable, Optional, Tuple, Union __all__ = ( "ConfigError", @@ -17,6 +18,7 @@ "Exit", "grz_tmp", "HARNESS_FILE", + "iter_entry_points", "package_version", "time_limits", "TIMEOUT_DELAY", @@ -105,6 +107,21 @@ def display_time_limits(time_limit: int, timeout: int, no_harness: bool) -> None LOG.warning("TIMEOUT DISABLED, not recommended for automation") +def iter_entry_points(group: str) -> Generator[EntryPoint, None, None]: + """Compatibility wrapper code for importlib.metadata.entry_points() + + Args: + group: See entry_points(). + + Yields: + EntryPoint + """ + # TODO: remove this function when support for Python 3.9 is dropped + assert group + assert version_info >= (3, 10) + yield from entry_points().select(group=group) + + def package_version(name: str, default: str = "unknown") -> str: """Get version of an installed package. diff --git a/grizzly/main.py b/grizzly/main.py index 69507b14..dc88ccd7 100644 --- a/grizzly/main.py +++ b/grizzly/main.py @@ -9,7 +9,7 @@ from sapphire import CertificateBundle, Sapphire from .adapter import Adapter -from .common.plugins import load as load_plugin +from .common.plugins import load_plugin from .common.reporter import ( FailedLaunchReporter, FilesystemReporter, diff --git a/grizzly/reduce/core.py b/grizzly/reduce/core.py index d7e31aa4..50b74235 100644 --- a/grizzly/reduce/core.py +++ b/grizzly/reduce/core.py @@ -18,7 +18,7 @@ from sapphire import CertificateBundle, Sapphire from ..common.fuzzmanager import CrashEntry -from ..common.plugins import load as load_plugin +from ..common.plugins import load_plugin from ..common.reporter import ( FailedLaunchReporter, FilesystemReporter, diff --git a/grizzly/reduce/strategies/__init__.py b/grizzly/reduce/strategies/__init__.py index 85d45a6b..94f4326c 100644 --- a/grizzly/reduce/strategies/__init__.py +++ b/grizzly/reduce/strategies/__init__.py @@ -30,11 +30,14 @@ cast, ) -from pkg_resources import iter_entry_points - from ...common.storage import TestCase from ...common.utils import grz_tmp +try: + from pkg_resources import iter_entry_points +except ImportError: + from ...common.utils import iter_entry_points # type: ignore + LOG = getLogger(__name__) diff --git a/grizzly/replay/replay.py b/grizzly/replay/replay.py index 47ecf2ee..dba150c3 100644 --- a/grizzly/replay/replay.py +++ b/grizzly/replay/replay.py @@ -13,7 +13,7 @@ from sapphire import CertificateBundle, Sapphire, ServerMap -from ..common.plugins import load as load_plugin +from ..common.plugins import load_plugin from ..common.report import Report from ..common.reporter import ( FailedLaunchReporter,