Skip to content

Commit

Permalink
Modernize typing
Browse files Browse the repository at this point in the history
  • Loading branch information
tysmith committed Sep 11, 2024
1 parent 73a9b80 commit 46a41f4
Show file tree
Hide file tree
Showing 37 changed files with 447 additions and 427 deletions.
24 changes: 13 additions & 11 deletions grizzly/adapter/adapter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple, final
from typing import Any, Generator, final

from sapphire import ServerMap

Expand Down Expand Up @@ -54,13 +56,13 @@ def __init__(self, name: str) -> None:
raise AdapterError("name must not be empty")
if len(name.split()) != 1 or name.strip() != name:
raise AdapterError("name must not contain whitespace")
self._harness: Optional[bytes] = None
self.fuzz: Dict[str, Any] = {}
self.monitor: Optional[TargetMonitor] = None
self._harness: bytes | None = None
self.fuzz: dict[str, Any] = {}
self.monitor: TargetMonitor | None = None
self.name = name
self.remaining: Optional[int] = None
self.remaining: int | None = None

def __enter__(self) -> "Adapter":
def __enter__(self) -> Adapter:
return self

def __exit__(self, *exc: Any) -> None:
Expand Down Expand Up @@ -94,7 +96,7 @@ def enable_harness(self, path: Path = HARNESS_FILE) -> None:
assert self._harness, f"empty harness file '{path.resolve()}'"

@final
def get_harness(self) -> Optional[bytes]:
def get_harness(self) -> bytes | None:
"""Get the harness. Used internally by Grizzly.
*** DO NOT OVERRIDE! ***
Expand All @@ -110,7 +112,7 @@ def get_harness(self) -> Optional[bytes]:
@staticmethod
def scan_path(
path: str,
ignore: Tuple[str, ...] = IGNORE_FILES,
ignore: tuple[str, ...] = IGNORE_FILES,
recursive: bool = False,
) -> Generator[str, None, None]:
"""Scan a path and yield the files within it. This is available as
Expand Down Expand Up @@ -149,7 +151,7 @@ def generate(self, testcase: TestCase, server_map: ServerMap) -> None:
None
"""

def on_served(self, testcase: TestCase, served: Tuple[str, ...]) -> None:
def on_served(self, testcase: TestCase, served: tuple[str, ...]) -> None:
"""Optional. Automatically called after a test case is successfully served.
Args:
Expand All @@ -160,7 +162,7 @@ def on_served(self, testcase: TestCase, served: Tuple[str, ...]) -> None:
None
"""

def on_timeout(self, testcase: TestCase, served: Tuple[str, ...]) -> None:
def on_timeout(self, testcase: TestCase, served: tuple[str, ...]) -> None:
"""Optional. Automatically called if timeout occurs while attempting to
serve a test case. By default it calls `self.on_served()`.
Expand All @@ -184,7 +186,7 @@ def pre_launch(self) -> None:
"""

# TODO: update input_path type (str -> Path)
def setup(self, input_path: Optional[str], server_map: ServerMap) -> None:
def setup(self, input_path: str | None, server_map: ServerMap) -> None:
"""Optional. Automatically called once at startup.
Args:
Expand Down
5 changes: 2 additions & 3 deletions grizzly/adapter/no_op_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from typing import Optional
from __future__ import annotations

from sapphire import ServerMap

Expand All @@ -20,7 +19,7 @@ class NoOpAdapter(Adapter):

NAME = "no-op"

def setup(self, input_path: Optional[str], server_map: ServerMap) -> None:
def setup(self, input_path: str | None, server_map: ServerMap) -> None:
"""Generate a static test case that calls `window.close()` when run.
Normally this is done in generate() but since the test is static only
do it once. Use the default harness to allow running multiple test cases
Expand Down
12 changes: 7 additions & 5 deletions grizzly/args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from argparse import (
Action,
ArgumentParser,
Expand All @@ -14,7 +16,7 @@
from pathlib import Path
from platform import system
from types import MappingProxyType
from typing import Iterable, List, Optional
from typing import Iterable

from FTB.ProgramConfiguration import ProgramConfiguration

Expand All @@ -26,18 +28,18 @@
# ref: https://stackoverflow.com/questions/12268602/sort-argparse-help-alphabetically
class SortingHelpFormatter(HelpFormatter):
@staticmethod
def __sort_key(action: Action) -> List[str]:
def __sort_key(action: Action) -> list[str]:
for opt in action.option_strings:
if opt.startswith("--"):
return [opt]
return list(action.option_strings)

def add_usage(
self,
usage: Optional[str],
usage: str | None,
actions: Iterable[Action],
groups: Iterable[_MutuallyExclusiveGroup],
prefix: Optional[str] = None,
prefix: str | None = None,
) -> None:
actions = sorted(actions, key=self.__sort_key)
super().add_usage(usage, actions, groups, prefix)
Expand Down Expand Up @@ -274,7 +276,7 @@ def is_headless() -> bool:
return True
return False

def parse_args(self, argv: Optional[List[str]] = None) -> Namespace:
def parse_args(self, argv: list[str] | None = None) -> Namespace:
args = self.parser.parse_args(argv)
self.sanity_check(args)
return args
Expand Down
14 changes: 8 additions & 6 deletions grizzly/common/bugzilla.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

import binascii
from base64 import b64decode
from logging import getLogger
from os import environ
from pathlib import Path
from shutil import rmtree
from tempfile import mkdtemp
from typing import Any, Generator, List, Optional, Tuple
from typing import Any, Generator
from zipfile import ZipFile

from bugsy import Bug, Bugsy
Expand All @@ -32,7 +34,7 @@ def __init__(self, bug: Bug) -> None:
self._data = Path(mkdtemp(prefix=f"bug{bug.id}-", dir=grz_tmp("bugzilla")))
self._fetch_attachments()

def __enter__(self) -> "BugzillaBug":
def __enter__(self) -> BugzillaBug:
return self

def __exit__(self, *exc: Any) -> None:
Expand Down Expand Up @@ -82,8 +84,8 @@ def _unpack_archives(self) -> None:
# TODO: add support for other archive types

def assets(
self, ignore: Optional[Tuple[str]] = None
) -> Generator[Tuple[str, Path], None, None]:
self, ignore: tuple[str] | None = None
) -> Generator[tuple[str, Path], None, None]:
"""Scan files for assets.
Arguments:
Expand All @@ -110,7 +112,7 @@ def cleanup(self) -> None:
rmtree(self._data)

@classmethod
def load(cls, bug_id: int) -> Optional["BugzillaBug"]:
def load(cls, bug_id: int) -> BugzillaBug | None:
"""Load bug information from a Bugzilla instance.
Arguments:
Expand All @@ -137,7 +139,7 @@ def load(cls, bug_id: int) -> Optional["BugzillaBug"]:
LOG.error("Unable to connect to %r (%s)", bugzilla.bugzilla_url, exc)
return None

def testcases(self) -> List[Path]:
def testcases(self) -> list[Path]:
"""Create a list of potential test cases.
Arguments:
Expand Down
36 changes: 19 additions & 17 deletions grizzly/common/fuzzmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Interface for getting Crash and Bucket data from CrashManager API"""
from __future__ import annotations

import json
from contextlib import contextmanager
from logging import getLogger
from pathlib import Path
from re import search
from shutil import copyfileobj, rmtree
from tempfile import NamedTemporaryFile, mkdtemp
from typing import Any, Dict, Generator, List, Optional, Tuple, cast
from typing import Any, Dict, Generator, cast
from zipfile import BadZipFile, ZipFile

from Collector.Collector import Collector
Expand Down Expand Up @@ -53,10 +55,10 @@ def __init__(self, crash_id: int) -> None:
assert isinstance(crash_id, int)
self._crash_id = crash_id
self._coll = Collector()
self._contents: Optional[List[Path]] = None
self._data: Optional[Dict[str, Any]] = None
self._storage: Optional[Path] = None
self._sig_filename: Optional[Path] = None
self._contents: list[Path] | None = None
self._data: dict[str, Any] | None = None
self._storage: Path | None = None
self._sig_filename: Path | None = None
self._url = (
f"{self._coll.serverProtocol}://{self._coll.serverHost}:"
f"{self._coll.serverPort}/crashmanager/rest/crashes/{crash_id}/"
Expand All @@ -66,7 +68,7 @@ def __init__(self, crash_id: int) -> None:
def crash_id(self) -> int:
return self._crash_id

def __enter__(self) -> "CrashEntry":
def __enter__(self) -> CrashEntry:
return self

def __exit__(self, *exc: Any) -> None:
Expand Down Expand Up @@ -112,7 +114,7 @@ def cleanup(self) -> None:
rmtree(self._sig_filename.parent)

@staticmethod
def _subset(tests: List[Path], subset: List[int]) -> List[Path]:
def _subset(tests: list[Path], subset: list[int]) -> list[Path]:
"""Select a subset of tests directories. Subset values are sanitized to
avoid raising.
Expand All @@ -133,8 +135,8 @@ def _subset(tests: List[Path], subset: List[int]) -> List[Path]:
return [tests[i] for i in sorted(keep)]

def testcases(
self, subset: Optional[List[int]] = None, ext: Optional[str] = None
) -> List[Path]:
self, subset: list[int] | None = None, ext: str | None = None
) -> list[Path]:
"""Download the testcase data from CrashManager.
Arguments:
Expand Down Expand Up @@ -270,19 +272,19 @@ def __init__(self, bucket_id: int) -> None:
"""
assert isinstance(bucket_id, int)
self._bucket_id = bucket_id
self._sig_filename: Optional[Path] = None
self._sig_filename: Path | None = None
self._coll = Collector()
self._url = (
f"{self._coll.serverProtocol}://{self._coll.serverHost}:"
f"{self._coll.serverPort}/crashmanager/rest/buckets/{bucket_id}/"
)
self._data: Optional[Dict[str, Any]] = None
self._data: dict[str, Any] | None = None

@property
def bucket_id(self) -> int:
return self._bucket_id

def __enter__(self) -> "Bucket":
def __enter__(self) -> Bucket:
return self

def __exit__(self, *exc: Any) -> None:
Expand Down Expand Up @@ -317,7 +319,7 @@ def cleanup(self) -> None:
rmtree(self._sig_filename.parent)

def iter_crashes(
self, quality_filter: Optional[int] = None
self, quality_filter: int | None = None
) -> Generator[CrashEntry, None, None]:
"""Fetch all crash IDs for this FuzzManager bucket.
Only crashes with testcases are returned.
Expand All @@ -330,8 +332,8 @@ def iter_crashes(
"""

def _get_results(
endpoint: str, params: Optional[Dict[str, str]] = None
) -> Generator[Dict[str, Any], None, None]:
endpoint: str, params: dict[str, str] | None = None
) -> Generator[dict[str, Any], None, None]:
"""
Function to get paginated results from FuzzManager
Expand All @@ -349,7 +351,7 @@ def _get_results(
f"{self._coll.serverPort}/crashmanager/rest/{endpoint}/"
)

response: Dict[str, Any] = self._coll.get(url, params=params).json()
response: dict[str, Any] = self._coll.get(url, params=params).json()

while True:
LOG.debug(
Expand Down Expand Up @@ -433,7 +435,7 @@ def signature_path(self) -> Path:
@contextmanager
def load_fm_data(
crash_id: int, load_bucket: bool = False
) -> Generator[Tuple[CrashEntry, Optional[Bucket]], None, None]:
) -> Generator[tuple[CrashEntry, Bucket | None], None, None]:
"""Load CrashEntry including Bucket from FuzzManager.
Arguments:
Expand Down
10 changes: 6 additions & 4 deletions grizzly/common/iomanager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, List, Optional
from __future__ import annotations

from typing import Any

from sapphire.server_map import ServerMap

Expand All @@ -25,13 +27,13 @@ def __init__(self, report_size: int = 1) -> None:
assert report_size > 0
self.server_map = ServerMap()
# tests will be ordered oldest to newest
self.tests: List[TestCase] = []
self.tests: list[TestCase] = []
# total number of test cases generated
self._generated = 0
self._report_size = report_size
self._test: Optional[TestCase] = None
self._test: TestCase | None = None

def __enter__(self) -> "IOManager":
def __enter__(self) -> IOManager:
return self

def __exit__(self, *exc: Any) -> None:
Expand Down
12 changes: 7 additions & 5 deletions grizzly/common/plugins.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from logging import getLogger
from typing import Any, Dict, List, Tuple
from typing import Any

try:
from pkg_resources import iter_entry_points
Expand Down Expand Up @@ -44,7 +46,7 @@ def load_plugin(name: str, group: str, base_type: type) -> Any:
return plugin


def scan_plugins(group: str) -> List[str]:
def scan_plugins(group: str) -> list[str]:
"""Scan for installed plug-ins.
Args:
Expand All @@ -53,7 +55,7 @@ def scan_plugins(group: str) -> List[str]:
Returns:
Names of installed entry points.
"""
found: List[str] = []
found: list[str] = []
LOG.debug("scanning %r", group)
for entry in iter_entry_points(group):
if entry.name in found:
Expand All @@ -63,7 +65,7 @@ def scan_plugins(group: str) -> List[str]:
return found


def scan_target_assets() -> Dict[str, Tuple[str, ...]]:
def scan_target_assets() -> dict[str, tuple[str, ...]]:
"""Scan targets and load collection of supported assets (minimal sanity checking).
Args:
Expand All @@ -72,7 +74,7 @@ def scan_target_assets() -> Dict[str, Tuple[str, ...]]:
Returns:
Name of target and list of supported assets.
"""
assets: Dict[str, Tuple[str, ...]] = {}
assets: dict[str, tuple[str, ...]] = {}
for entry in iter_entry_points("grizzly_targets"):
assets[entry.name] = entry.load().SUPPORTED_ASSETS
return assets
Loading

0 comments on commit 46a41f4

Please sign in to comment.