Skip to content

Commit

Permalink
Followup changes to fix ruff & pyright warnings (#203)
Browse files Browse the repository at this point in the history
* Annotate error_codes with Mapping instead of dict to silence warnings about mutable classvar, Write __hash__ for Statement
  • Loading branch information
jakkdl authored Feb 22, 2024
1 parent 953a945 commit bae0a63
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 56 deletions.
10 changes: 7 additions & 3 deletions flake8_trio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,22 @@ class Statement(NamedTuple):
lineno: int
col_offset: int = -1

# pyright is unhappy about defining __eq__ but not __hash__ .. which it should
# but it works :tm: and needs changing in a couple places to avoid it.
def __eq__(self, other: object) -> bool:
return (
isinstance(other, Statement)
and self[:2] == other[:2]
and self.name == other.name
and self.lineno == other.lineno
and (
self.col_offset == other.col_offset
or -1 in (self.col_offset, other.col_offset)
)
)

# Objects that are equal needs to have the same hash, so we don't hash on
# `col_offset` since it's a "wildcard" value
def __hash__(self) -> int:
return hash((self.name, self.lineno))


class Error:
def __init__(
Expand Down
4 changes: 2 additions & 2 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .visitors.visitor_utility import NoqaHandler

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping

from libcst import Module

Expand All @@ -46,7 +46,7 @@ def __init__(self, options: Options):
super().__init__()
self.state = SharedState(options)

def selected(self, error_codes: dict[str, str]) -> bool:
def selected(self, error_codes: Mapping[str, str]) -> bool:
enabled_or_autofix = (
self.state.options.enabled_codes | self.state.options.autofix_codes
)
Expand Down
8 changes: 4 additions & 4 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

import ast
from abc import ABC
from typing import TYPE_CHECKING, Any, ClassVar, Union
from typing import TYPE_CHECKING, Any, Union

import libcst as cst
from libcst.metadata import PositionProvider

from ..base import Error, Statement

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping

from ..runner import SharedState

Expand All @@ -23,7 +23,7 @@

class Flake8TrioVisitor(ast.NodeVisitor, ABC):
# abstract attribute by not providing a value
error_codes: ClassVar[dict[str, str]]
error_codes: Mapping[str, str]

def __init__(self, shared_state: SharedState):
super().__init__()
Expand Down Expand Up @@ -158,7 +158,7 @@ def add_library(self, name: str) -> None:

class Flake8TrioVisitor_cst(cst.CSTTransformer, ABC):
# abstract attribute by not providing a value
error_codes: dict[str, str]
error_codes: Mapping[str, str]
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, shared_state: SharedState):
Expand Down
7 changes: 5 additions & 2 deletions flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

import libcst as cst
import libcst.matchers as m
Expand All @@ -21,10 +21,13 @@
with_has_call,
)

if TYPE_CHECKING:
from collections.abc import Mapping


@error_class_cst
class Visitor100_libcst(Flake8TrioVisitor_cst):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO100": (
"{0}.{1} context contains no checkpoints, remove the context or add"
" `await {0}.lowlevel.checkpoint()`."
Expand Down
4 changes: 3 additions & 1 deletion flake8_trio/visitors/visitor101.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
)

if TYPE_CHECKING:
from collections.abc import Mapping

import libcst as cst


@error_class_cst
class Visitor101(Flake8TrioVisitor_cst):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO101": (
"`yield` inside a nursery or cancel scope is only safe when implementing "
"a context manager - otherwise, it breaks exception handling."
Expand Down
7 changes: 5 additions & 2 deletions flake8_trio/visitors/visitor102.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
from __future__ import annotations

import ast
from typing import Any
from typing import TYPE_CHECKING, Any

from ..base import Statement
from .flake8triovisitor import Flake8TrioVisitor
from .helpers import cancel_scope_names, critical_except, error_class, get_matching_call

if TYPE_CHECKING:
from collections.abc import Mapping


@error_class
class Visitor102(Flake8TrioVisitor):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO102": (
"await inside {0.name} on line {0.lineno} must have shielded cancel "
"scope with a timeout."
Expand Down
23 changes: 14 additions & 9 deletions flake8_trio/visitors/visitor103_104.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
from __future__ import annotations

import ast
from typing import Any
from typing import TYPE_CHECKING, Any

from .flake8triovisitor import Flake8TrioVisitor
from .helpers import critical_except, error_class, iter_guaranteed_once

if TYPE_CHECKING:
from collections.abc import Mapping

_trio103_common_msg = "{} block with a code path that doesn't re-raise the error."
_suggestion = " Consider adding an `except {}: raise` before this exception handler."
_suggestion_dict: dict[tuple[str, ...], str] = {
Expand All @@ -22,17 +25,19 @@
}
_suggestion_dict[("anyio", "trio")] = "[" + "|".join(_suggestion_dict.values()) + "]"

_error_codes = {
"TRIO103": _trio103_common_msg,
"TRIO104": "Cancelled (and therefore BaseException) must be re-raised.",
}
for poss_library in _suggestion_dict:
_error_codes[f"TRIO103_{'_'.join(poss_library)}"] = (
_trio103_common_msg + _suggestion.format(_suggestion_dict[poss_library])
)


@error_class
class Visitor103_104(Flake8TrioVisitor):
error_codes = {
"TRIO103": _trio103_common_msg,
"TRIO104": "Cancelled (and therefore BaseException) must be re-raised.",
}
for poss_library in _suggestion_dict:
error_codes[f"TRIO103_{'_'.join(poss_library)}"] = (
_trio103_common_msg + _suggestion.format(_suggestion_dict[poss_library])
)
error_codes: Mapping[str, str] = _error_codes

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
Expand Down
8 changes: 6 additions & 2 deletions flake8_trio/visitors/visitor105.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
from __future__ import annotations

import ast
from typing import Any
from typing import TYPE_CHECKING, Any

from .flake8triovisitor import Flake8TrioVisitor
from .helpers import error_class

if TYPE_CHECKING:
from collections.abc import Mapping


# used in 105
trio_async_funcs = (
"trio.aclose_forcefully",
Expand Down Expand Up @@ -39,7 +43,7 @@

@error_class
class Visitor105(Flake8TrioVisitor):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO105": "{0} async {1} must be immediately awaited.",
}

Expand Down
7 changes: 5 additions & 2 deletions flake8_trio/visitors/visitor111.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
from __future__ import annotations

import ast
from typing import Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple

from .flake8triovisitor import Flake8TrioVisitor
from .helpers import error_class, get_matching_call

if TYPE_CHECKING:
from collections.abc import Mapping


@error_class
class Visitor111(Flake8TrioVisitor):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO111": (
"variable {2} is usable within the context manager on line {0}, but that "
"will close before nursery opened on line {1} - this is usually a bug. "
Expand Down
6 changes: 5 additions & 1 deletion flake8_trio/visitors/visitor118.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@

import ast
import re
from typing import TYPE_CHECKING

from .flake8triovisitor import Flake8TrioVisitor
from .helpers import error_class

if TYPE_CHECKING:
from collections.abc import Mapping


@error_class
class Visitor118(Flake8TrioVisitor):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO118": (
"Don't assign the value of `anyio.get_cancelled_exc_class()` to a variable,"
" since that breaks linter checks and multi-backend programs."
Expand Down
19 changes: 11 additions & 8 deletions flake8_trio/visitors/visitor2xx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@

import ast
import re
from typing import Any
from typing import TYPE_CHECKING, Any

from .flake8triovisitor import Flake8TrioVisitor
from .helpers import error_class, fnmatch_qualified_name, get_matching_call

if TYPE_CHECKING:
from collections.abc import Mapping


@error_class
class Visitor200(Flake8TrioVisitor):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO200": (
"User-configured blocking sync call {0} in async function, consider "
"replacing with {1}."
Expand Down Expand Up @@ -55,7 +58,7 @@ def visit_blocking_call(self, node: ast.Call):

@error_class
class Visitor21X(Visitor200):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO210": "Sync HTTP call {} in async function, use `httpx.AsyncClient`.",
"TRIO211": (
"Likely sync HTTP call {} in async function, use `httpx.AsyncClient`."
Expand Down Expand Up @@ -114,7 +117,7 @@ def visit_blocking_call(self, node: ast.Call):

@error_class
class Visitor212(Visitor200):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO212": (
"Blocking sync HTTP call {1} on httpx object {0}, use httpx.AsyncClient."
)
Expand Down Expand Up @@ -166,7 +169,7 @@ def visit_blocking_call(self, node: ast.Call):
# Process invocations 202
@error_class
class Visitor22X(Visitor200):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO220": (
"Sync call {} in async function, use "
"`await nursery.start({}.run_process, ...)`."
Expand Down Expand Up @@ -225,7 +228,7 @@ def is_p_wait(arg: ast.expr) -> bool:

@error_class
class Visitor23X(Visitor200):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO230": "Sync call {0} in async function, use `{1}.open_file(...)`.",
"TRIO231": "Sync call {0} in async function, use `{1}.wrap_file({0})`.",
}
Expand All @@ -251,7 +254,7 @@ def visit_blocking_call(self, node: ast.Call):

@error_class
class Visitor232(Visitor200):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO232": (
"Blocking sync call {1} on file object {0}, wrap the file object"
"in `{2}.wrap_file()` to get an async file object."
Expand Down Expand Up @@ -281,7 +284,7 @@ def visit_blocking_call(self, node: ast.Call):

@error_class
class Visitor24X(Visitor200):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO240": "Avoid using os.path, prefer using {1}.Path objects.",
}

Expand Down
10 changes: 4 additions & 6 deletions flake8_trio/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
)

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Mapping, Sequence


# Statement injected at the start of loops to track missed checkpoints.
ARTIFICIAL_STATEMENT = Statement("artificial", -1)


Expand Down Expand Up @@ -226,7 +227,7 @@ def leave_Yield(
@error_class_cst
@disabled_by_default
class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
error_codes = {
error_codes: Mapping[str, str] = {
"TRIO910": (
"{0} from async function with no guaranteed checkpoint or exception "
"since function definition on line {1.lineno}."
Expand Down Expand Up @@ -591,10 +592,7 @@ def visit_While_body(self, node: cst.For | cst.While):
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
else:
# pyright correctly dislikes Statement defining __eq__ but not __hash__
# but it works:tm:, and changing it touches on various bits of code, so
# leaving it for another time.
self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT} # pyright: ignore
self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT}

self.loop_state.uncheckpointed_before_continue = set()
self.loop_state.uncheckpointed_before_break = set()
Expand Down
Loading

0 comments on commit bae0a63

Please sign in to comment.