diff --git a/mypy/semanal.py b/mypy/semanal.py index 9d2d6f5fa484..bad2dfbefae0 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -484,6 +484,12 @@ def __init__( # Used to pass information about current overload index to visit_func_def(). self.current_overload_item: int | None = None + # Used to track whether currently inside an except* block. This helps + # to invoke errors when continue/break/return is used inside except* block. + self.inside_except_star_block: bool = False + # Used to track edge case when return is still inside except* if it enters a loop + self.return_stmt_inside_except_star_block: bool = False + # mypyc doesn't properly handle implementing an abstractproperty # with a regular attribute so we make them properties @property @@ -511,6 +517,25 @@ def allow_unbound_tvars_set(self) -> Iterator[None]: finally: self.allow_unbound_tvars = old + @contextmanager + def inside_except_star_block_set( + self, value: bool, entering_loop: bool = False + ) -> Iterator[None]: + old = self.inside_except_star_block + self.inside_except_star_block = value + + # Return statement would still be in except* scope if entering loops + if not entering_loop: + old_return_stmt_flag = self.return_stmt_inside_except_star_block + self.return_stmt_inside_except_star_block = value + + try: + yield + finally: + self.inside_except_star_block = old + if not entering_loop: + self.return_stmt_inside_except_star_block = old_return_stmt_flag + # # Preparing module (performed before semantic analysis) # @@ -877,7 +902,8 @@ def visit_func_def(self, defn: FuncDef) -> None: return with self.scope.function_scope(defn): - self.analyze_func_def(defn) + with self.inside_except_star_block_set(value=False): + self.analyze_func_def(defn) def function_fullname(self, fullname: str) -> str: if self.current_overload_item is None: @@ -5264,6 +5290,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> None: self.statement = s if not self.is_func_scope(): self.fail('"return" outside function', s) + if self.return_stmt_inside_except_star_block: + self.fail('"return" not allowed in except* block', s, serious=True) if s.expr: s.expr.accept(self) @@ -5297,7 +5325,8 @@ def visit_while_stmt(self, s: WhileStmt) -> None: self.statement = s s.expr.accept(self) self.loop_depth[-1] += 1 - s.body.accept(self) + with self.inside_except_star_block_set(value=False, entering_loop=True): + s.body.accept(self) self.loop_depth[-1] -= 1 self.visit_block_maybe(s.else_body) @@ -5321,20 +5350,24 @@ def visit_for_stmt(self, s: ForStmt) -> None: s.index_type = analyzed self.loop_depth[-1] += 1 - self.visit_block(s.body) + with self.inside_except_star_block_set(value=False, entering_loop=True): + self.visit_block(s.body) self.loop_depth[-1] -= 1 - self.visit_block_maybe(s.else_body) def visit_break_stmt(self, s: BreakStmt) -> None: self.statement = s if self.loop_depth[-1] == 0: self.fail('"break" outside loop', s, serious=True, blocker=True) + if self.inside_except_star_block: + self.fail('"break" not allowed in except* block', s, serious=True) def visit_continue_stmt(self, s: ContinueStmt) -> None: self.statement = s if self.loop_depth[-1] == 0: self.fail('"continue" outside loop', s, serious=True, blocker=True) + if self.inside_except_star_block: + self.fail('"continue" not allowed in except* block', s, serious=True) def visit_if_stmt(self, s: IfStmt) -> None: self.statement = s @@ -5355,7 +5388,8 @@ def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor[None]) -> None: type.accept(visitor) if var: self.analyze_lvalue(var) - handler.accept(visitor) + with self.inside_except_star_block_set(self.inside_except_star_block or s.is_star): + handler.accept(visitor) if s.else_body: s.else_body.accept(visitor) if s.finally_body: diff --git a/test-data/unit/check-python311.test b/test-data/unit/check-python311.test index 28951824999f..6f4c540572b0 100644 --- a/test-data/unit/check-python311.test +++ b/test-data/unit/check-python311.test @@ -173,3 +173,89 @@ Alias4 = Callable[[*IntList], int] # E: "List[int]" cannot be unpacked (must be x4: Alias4[int] # E: Bad number of arguments for type alias, expected 0, given 1 reveal_type(x4) # N: Revealed type is "def (*Any) -> builtins.int" [builtins fixtures/tuple.pyi] + +[case testReturnInExceptStarBlock1] +# flags: --python-version 3.11 +def foo() -> None: + try: + pass + except* Exception: + return # E: "return" not allowed in except* block + finally: + return +[builtins fixtures/exception.pyi] + +[case testReturnInExceptStarBlock2] +# flags: --python-version 3.11 +def foo(): + while True: + try: + pass + except* Exception: + while True: + return # E: "return" not allowed in except* block +[builtins fixtures/exception.pyi] + +[case testContinueInExceptBlockNestedInExceptStarBlock] +# flags: --python-version 3.11 +while True: + try: + ... + except* Exception: + try: + ... + except Exception: + continue # E: "continue" not allowed in except* block + continue # E: "continue" not allowed in except* block +[builtins fixtures/exception.pyi] + +[case testReturnInExceptBlockNestedInExceptStarBlock] +# flags: --python-version 3.11 +def foo(): + try: + ... + except* Exception: + try: + ... + except Exception: + return # E: "return" not allowed in except* block + return # E: "return" not allowed in except* block +[builtins fixtures/exception.pyi] + +[case testBreakContinueReturnInExceptStarBlock1] +# flags: --python-version 3.11 +from typing import Iterable +def foo(x: Iterable[int]) -> None: + for _ in x: + try: + pass + except* Exception: + continue # E: "continue" not allowed in except* block + except* Exception: + for _ in x: + continue + break # E: "break" not allowed in except* block + except* Exception: + return # E: "return" not allowed in except* block +[builtins fixtures/exception.pyi] + +[case testBreakContinueReturnInExceptStarBlock2] +# flags: --python-version 3.11 +def foo(): + while True: + try: + pass + except* Exception: + def inner(): + while True: + if 1 < 1: + continue + else: + break + return + if 1 < 2: + break # E: "break" not allowed in except* block + if 1 < 2: + continue # E: "continue" not allowed in except* block + return # E: "return" not allowed in except* block +[builtins fixtures/exception.pyi]