Skip to content

Commit

Permalink
Added checks for invalid usage of continue/break/return in except* bl…
Browse files Browse the repository at this point in the history
…ock (#18132)

Fixes #18123 

This PR addresses an issue where mypy incorrectly allows
break/continue/return statements in the except* block. (see
https://peps.python.org/pep-0654/#forbidden-combinations)
  • Loading branch information
coldwolverine authored Nov 21, 2024
1 parent 08340c2 commit 499adae
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 5 deletions.
44 changes: 39 additions & 5 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,12 @@ def __init__(
# Used to pass information about current overload index to visit_func_def().
self.current_overload_item: int | None = None

# Used to track whether currently inside an except* block. This helps
# to invoke errors when continue/break/return is used inside except* block.
self.inside_except_star_block: bool = False
# Used to track edge case when return is still inside except* if it enters a loop
self.return_stmt_inside_except_star_block: bool = False

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -511,6 +517,25 @@ def allow_unbound_tvars_set(self) -> Iterator[None]:
finally:
self.allow_unbound_tvars = old

@contextmanager
def inside_except_star_block_set(
self, value: bool, entering_loop: bool = False
) -> Iterator[None]:
old = self.inside_except_star_block
self.inside_except_star_block = value

# Return statement would still be in except* scope if entering loops
if not entering_loop:
old_return_stmt_flag = self.return_stmt_inside_except_star_block
self.return_stmt_inside_except_star_block = value

try:
yield
finally:
self.inside_except_star_block = old
if not entering_loop:
self.return_stmt_inside_except_star_block = old_return_stmt_flag

#
# Preparing module (performed before semantic analysis)
#
Expand Down Expand Up @@ -877,7 +902,8 @@ def visit_func_def(self, defn: FuncDef) -> None:
return

with self.scope.function_scope(defn):
self.analyze_func_def(defn)
with self.inside_except_star_block_set(value=False):
self.analyze_func_def(defn)

def function_fullname(self, fullname: str) -> str:
if self.current_overload_item is None:
Expand Down Expand Up @@ -5264,6 +5290,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> None:
self.statement = s
if not self.is_func_scope():
self.fail('"return" outside function', s)
if self.return_stmt_inside_except_star_block:
self.fail('"return" not allowed in except* block', s, serious=True)
if s.expr:
s.expr.accept(self)

Expand Down Expand Up @@ -5297,7 +5325,8 @@ def visit_while_stmt(self, s: WhileStmt) -> None:
self.statement = s
s.expr.accept(self)
self.loop_depth[-1] += 1
s.body.accept(self)
with self.inside_except_star_block_set(value=False, entering_loop=True):
s.body.accept(self)
self.loop_depth[-1] -= 1
self.visit_block_maybe(s.else_body)

Expand All @@ -5321,20 +5350,24 @@ def visit_for_stmt(self, s: ForStmt) -> None:
s.index_type = analyzed

self.loop_depth[-1] += 1
self.visit_block(s.body)
with self.inside_except_star_block_set(value=False, entering_loop=True):
self.visit_block(s.body)
self.loop_depth[-1] -= 1

self.visit_block_maybe(s.else_body)

def visit_break_stmt(self, s: BreakStmt) -> None:
self.statement = s
if self.loop_depth[-1] == 0:
self.fail('"break" outside loop', s, serious=True, blocker=True)
if self.inside_except_star_block:
self.fail('"break" not allowed in except* block', s, serious=True)

def visit_continue_stmt(self, s: ContinueStmt) -> None:
self.statement = s
if self.loop_depth[-1] == 0:
self.fail('"continue" outside loop', s, serious=True, blocker=True)
if self.inside_except_star_block:
self.fail('"continue" not allowed in except* block', s, serious=True)

def visit_if_stmt(self, s: IfStmt) -> None:
self.statement = s
Expand All @@ -5355,7 +5388,8 @@ def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor[None]) -> None:
type.accept(visitor)
if var:
self.analyze_lvalue(var)
handler.accept(visitor)
with self.inside_except_star_block_set(self.inside_except_star_block or s.is_star):
handler.accept(visitor)
if s.else_body:
s.else_body.accept(visitor)
if s.finally_body:
Expand Down
86 changes: 86 additions & 0 deletions test-data/unit/check-python311.test
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,89 @@ Alias4 = Callable[[*IntList], int] # E: "List[int]" cannot be unpacked (must be
x4: Alias4[int] # E: Bad number of arguments for type alias, expected 0, given 1
reveal_type(x4) # N: Revealed type is "def (*Any) -> builtins.int"
[builtins fixtures/tuple.pyi]

[case testReturnInExceptStarBlock1]
# flags: --python-version 3.11
def foo() -> None:
try:
pass
except* Exception:
return # E: "return" not allowed in except* block
finally:
return
[builtins fixtures/exception.pyi]

[case testReturnInExceptStarBlock2]
# flags: --python-version 3.11
def foo():
while True:
try:
pass
except* Exception:
while True:
return # E: "return" not allowed in except* block
[builtins fixtures/exception.pyi]

[case testContinueInExceptBlockNestedInExceptStarBlock]
# flags: --python-version 3.11
while True:
try:
...
except* Exception:
try:
...
except Exception:
continue # E: "continue" not allowed in except* block
continue # E: "continue" not allowed in except* block
[builtins fixtures/exception.pyi]

[case testReturnInExceptBlockNestedInExceptStarBlock]
# flags: --python-version 3.11
def foo():
try:
...
except* Exception:
try:
...
except Exception:
return # E: "return" not allowed in except* block
return # E: "return" not allowed in except* block
[builtins fixtures/exception.pyi]

[case testBreakContinueReturnInExceptStarBlock1]
# flags: --python-version 3.11
from typing import Iterable
def foo(x: Iterable[int]) -> None:
for _ in x:
try:
pass
except* Exception:
continue # E: "continue" not allowed in except* block
except* Exception:
for _ in x:
continue
break # E: "break" not allowed in except* block
except* Exception:
return # E: "return" not allowed in except* block
[builtins fixtures/exception.pyi]

[case testBreakContinueReturnInExceptStarBlock2]
# flags: --python-version 3.11
def foo():
while True:
try:
pass
except* Exception:
def inner():
while True:
if 1 < 1:
continue
else:
break
return
if 1 < 2:
break # E: "break" not allowed in except* block
if 1 < 2:
continue # E: "continue" not allowed in except* block
return # E: "return" not allowed in except* block
[builtins fixtures/exception.pyi]

0 comments on commit 499adae

Please sign in to comment.