Skip to content

Commit

Permalink
Merge branch 'master' into narrowing/refine_partial_types_in_loops
Browse files Browse the repository at this point in the history
# Conflicts:
#	test-data/unit/check-narrowing.test
  • Loading branch information
tyralla committed Nov 23, 2024
2 parents c8eee51 + 499adae commit 8844183
Show file tree
Hide file tree
Showing 12 changed files with 273 additions and 48 deletions.
55 changes: 39 additions & 16 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import defaultdict
from contextlib import contextmanager
from typing import DefaultDict, Iterator, List, Optional, Tuple, Union, cast
from typing import DefaultDict, Iterator, List, NamedTuple, Optional, Tuple, Union
from typing_extensions import TypeAlias as _TypeAlias

from mypy.erasetype import remove_instance_last_known_values
Expand Down Expand Up @@ -30,6 +30,11 @@
BindableExpression: _TypeAlias = Union[IndexExpr, MemberExpr, NameExpr]


class CurrentType(NamedTuple):
type: Type
from_assignment: bool


class Frame:
"""A Frame represents a specific point in the execution of a program.
It carries information about the current types of expressions at
Expand All @@ -44,7 +49,7 @@ class Frame:

def __init__(self, id: int, conditional_frame: bool = False) -> None:
self.id = id
self.types: dict[Key, Type] = {}
self.types: dict[Key, CurrentType] = {}
self.unreachable = False
self.conditional_frame = conditional_frame
self.suppress_unreachable_warnings = False
Expand Down Expand Up @@ -132,18 +137,18 @@ def push_frame(self, conditional_frame: bool = False) -> Frame:
self.options_on_return.append([])
return f

def _put(self, key: Key, type: Type, index: int = -1) -> None:
self.frames[index].types[key] = type
def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None:
self.frames[index].types[key] = CurrentType(type, from_assignment)

def _get(self, key: Key, index: int = -1) -> Type | None:
def _get(self, key: Key, index: int = -1) -> CurrentType | None:
if index < 0:
index += len(self.frames)
for i in range(index, -1, -1):
if key in self.frames[i].types:
return self.frames[i].types[key]
return None

def put(self, expr: Expression, typ: Type) -> None:
def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None:
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
return
if not literal(expr):
Expand All @@ -153,7 +158,7 @@ def put(self, expr: Expression, typ: Type) -> None:
if key not in self.declarations:
self.declarations[key] = get_declaration(expr)
self._add_dependencies(key)
self._put(key, typ)
self._put(key, typ, from_assignment)

def unreachable(self) -> None:
self.frames[-1].unreachable = True
Expand All @@ -164,7 +169,10 @@ def suppress_unreachable_warnings(self) -> None:
def get(self, expr: Expression) -> Type | None:
key = literal_hash(expr)
assert key is not None, "Internal error: binder tried to get non-literal"
return self._get(key)
found = self._get(key)
if found is None:
return None
return found.type

def is_unreachable(self) -> bool:
# TODO: Copy the value of unreachable into new frames to avoid
Expand Down Expand Up @@ -193,7 +201,7 @@ def update_from_options(self, frames: list[Frame]) -> bool:
If a key is declared as AnyType, only update it if all the
options are the same.
"""

all_reachable = all(not f.unreachable for f in frames)
frames = [f for f in frames if not f.unreachable]
changed = False
keys = {key for f in frames for key in f.types}
Expand All @@ -207,17 +215,30 @@ def update_from_options(self, frames: list[Frame]) -> bool:
# know anything about key in at least one possible frame.
continue

type = resulting_values[0]
assert type is not None
if all_reachable and all(
x is not None and not x.from_assignment for x in resulting_values
):
# Do not synthesize a new type if we encountered a conditional block
# (if, while or match-case) without assignments.
# See check-isinstance.test::testNoneCheckDoesNotMakeTypeVarOptional
# This is a safe assumption: the fact that we checked something with `is`
# or `isinstance` does not change the type of the value.
continue

current_type = resulting_values[0]
assert current_type is not None
type = current_type.type
declaration_type = get_proper_type(self.declarations.get(key))
if isinstance(declaration_type, AnyType):
# At this point resulting values can't contain None, see continue above
if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]):
if not all(
t is not None and is_same_type(type, t.type) for t in resulting_values[1:]
):
type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type)
else:
for other in resulting_values[1:]:
assert other is not None
type = join_simple(self.declarations[key], type, other)
type = join_simple(self.declarations[key], type, other.type)
# Try simplifying resulting type for unions involving variadic tuples.
# Technically, everything is still valid without this step, but if we do
# not do this, this may create long unions after exiting an if check like:
Expand All @@ -236,8 +257,8 @@ def update_from_options(self, frames: list[Frame]) -> bool:
)
if simplified == self.declarations[key]:
type = simplified
if current_value is None or not is_same_type(type, current_value):
self._put(key, type)
if current_value is None or not is_same_type(type, current_value[0]):
self._put(key, type, from_assignment=True)
changed = True

self.frames[-1].unreachable = not frames
Expand Down Expand Up @@ -374,7 +395,9 @@ def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Ty
key = literal_hash(expr)
assert key is not None
enclosers = [get_declaration(expr)] + [
f.types[key] for f in self.frames if key in f.types and is_subtype(type, f.types[key])
f.types[key].type
for f in self.frames
if key in f.types and is_subtype(type, f.types[key][0])
]
return enclosers[-1]

Expand Down
27 changes: 14 additions & 13 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4728,11 +4728,11 @@ def visit_if_stmt(self, s: IfStmt) -> None:

# XXX Issue a warning if condition is always False?
with self.binder.frame_context(can_skip=True, fall_through=2):
self.push_type_map(if_map)
self.push_type_map(if_map, from_assignment=False)
self.accept(b)

# XXX Issue a warning if condition is always True?
self.push_type_map(else_map)
self.push_type_map(else_map, from_assignment=False)

with self.binder.frame_context(can_skip=False, fall_through=2):
if s.else_body:
Expand Down Expand Up @@ -5313,18 +5313,21 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
if b.is_unreachable or isinstance(
get_proper_type(pattern_type.type), UninhabitedType
):
self.push_type_map(None)
self.push_type_map(None, from_assignment=False)
else_map: TypeMap = {}
else:
pattern_map, else_map = conditional_types_to_typemaps(
named_subject, pattern_type.type, pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
self.push_type_map(pattern_map, from_assignment=False)
if pattern_map:
for expr, typ in pattern_map.items():
self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ))
self.push_type_map(pattern_type.captures)
self.push_type_map(
self._get_recursive_sub_patterns_map(expr, typ),
from_assignment=False,
)
self.push_type_map(pattern_type.captures, from_assignment=False)
if g is not None:
with self.binder.frame_context(can_skip=False, fall_through=3):
gt = get_proper_type(self.expr_checker.accept(g))
Expand All @@ -5350,11 +5353,11 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
continue
type_map[named_subject] = type_map[expr]

self.push_type_map(guard_map)
self.push_type_map(guard_map, from_assignment=False)
self.accept(b)
else:
self.accept(b)
self.push_type_map(else_map)
self.push_type_map(else_map, from_assignment=False)

# This is needed due to a quirk in frame_context. Without it types will stay narrowed
# after the match.
Expand Down Expand Up @@ -7375,12 +7378,12 @@ def iterable_item_type(
def function_type(self, func: FuncBase) -> FunctionLike:
return function_type(func, self.named_type("builtins.function"))

def push_type_map(self, type_map: TypeMap) -> None:
def push_type_map(self, type_map: TypeMap, *, from_assignment: bool = True) -> None:
if type_map is None:
self.binder.unreachable()
else:
for expr, type in type_map.items():
self.binder.put(expr, type)
self.binder.put(expr, type, from_assignment=from_assignment)

def infer_issubclass_maps(self, node: CallExpr, expr: Expression) -> tuple[TypeMap, TypeMap]:
"""Infer type restrictions for an expression in issubclass call."""
Expand Down Expand Up @@ -7753,9 +7756,7 @@ def conditional_types(
) and is_proper_subtype(current_type, proposed_type, ignore_promotions=True):
# Expression is always of one of the types in proposed_type_ranges
return default, UninhabitedType()
elif not is_overlapping_types(
current_type, proposed_type, prohibit_none_typevar_overlap=True, ignore_promotions=True
):
elif not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
# Expression is never of any type in proposed_type_ranges
return UninhabitedType(), default
else:
Expand Down
45 changes: 40 additions & 5 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,12 @@ def __init__(
# Used to pass information about current overload index to visit_func_def().
self.current_overload_item: int | None = None

# Used to track whether currently inside an except* block. This helps
# to invoke errors when continue/break/return is used inside except* block.
self.inside_except_star_block: bool = False
# Used to track edge case when return is still inside except* if it enters a loop
self.return_stmt_inside_except_star_block: bool = False

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -511,6 +517,25 @@ def allow_unbound_tvars_set(self) -> Iterator[None]:
finally:
self.allow_unbound_tvars = old

@contextmanager
def inside_except_star_block_set(
self, value: bool, entering_loop: bool = False
) -> Iterator[None]:
old = self.inside_except_star_block
self.inside_except_star_block = value

# Return statement would still be in except* scope if entering loops
if not entering_loop:
old_return_stmt_flag = self.return_stmt_inside_except_star_block
self.return_stmt_inside_except_star_block = value

try:
yield
finally:
self.inside_except_star_block = old
if not entering_loop:
self.return_stmt_inside_except_star_block = old_return_stmt_flag

#
# Preparing module (performed before semantic analysis)
#
Expand Down Expand Up @@ -877,7 +902,8 @@ def visit_func_def(self, defn: FuncDef) -> None:
return

with self.scope.function_scope(defn):
self.analyze_func_def(defn)
with self.inside_except_star_block_set(value=False):
self.analyze_func_def(defn)

def function_fullname(self, fullname: str) -> str:
if self.current_overload_item is None:
Expand Down Expand Up @@ -1684,6 +1710,7 @@ def visit_decorator(self, dec: Decorator) -> None:
"abc.abstractproperty",
"functools.cached_property",
"enum.property",
"types.DynamicClassAttribute",
),
):
removed.append(i)
Expand Down Expand Up @@ -5263,6 +5290,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> None:
self.statement = s
if not self.is_func_scope():
self.fail('"return" outside function', s)
if self.return_stmt_inside_except_star_block:
self.fail('"return" not allowed in except* block', s, serious=True)
if s.expr:
s.expr.accept(self)

Expand Down Expand Up @@ -5296,7 +5325,8 @@ def visit_while_stmt(self, s: WhileStmt) -> None:
self.statement = s
s.expr.accept(self)
self.loop_depth[-1] += 1
s.body.accept(self)
with self.inside_except_star_block_set(value=False, entering_loop=True):
s.body.accept(self)
self.loop_depth[-1] -= 1
self.visit_block_maybe(s.else_body)

Expand All @@ -5320,20 +5350,24 @@ def visit_for_stmt(self, s: ForStmt) -> None:
s.index_type = analyzed

self.loop_depth[-1] += 1
self.visit_block(s.body)
with self.inside_except_star_block_set(value=False, entering_loop=True):
self.visit_block(s.body)
self.loop_depth[-1] -= 1

self.visit_block_maybe(s.else_body)

def visit_break_stmt(self, s: BreakStmt) -> None:
self.statement = s
if self.loop_depth[-1] == 0:
self.fail('"break" outside loop', s, serious=True, blocker=True)
if self.inside_except_star_block:
self.fail('"break" not allowed in except* block', s, serious=True)

def visit_continue_stmt(self, s: ContinueStmt) -> None:
self.statement = s
if self.loop_depth[-1] == 0:
self.fail('"continue" outside loop', s, serious=True, blocker=True)
if self.inside_except_star_block:
self.fail('"continue" not allowed in except* block', s, serious=True)

def visit_if_stmt(self, s: IfStmt) -> None:
self.statement = s
Expand All @@ -5354,7 +5388,8 @@ def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor[None]) -> None:
type.accept(visitor)
if var:
self.analyze_lvalue(var)
handler.accept(visitor)
with self.inside_except_star_block_set(self.inside_except_star_block or s.is_star):
handler.accept(visitor)
if s.else_body:
s.else_body.accept(visitor)
if s.finally_body:
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def analyze_type_alias(
)
analyzer.in_dynamic_func = in_dynamic_func
analyzer.global_scope = global_scope
res = type.accept(analyzer)
res = analyzer.anal_type(type, nested=False)
return res, analyzer.aliases_used


Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ elif x is Foo.C:
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.C]"
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is "__main__.Foo"
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"

if Foo.A is x:
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.A]"
Expand All @@ -825,7 +825,7 @@ elif Foo.C is x:
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.C]"
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is "__main__.Foo"
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"

y: Foo
if y is Foo.A:
Expand Down
17 changes: 9 additions & 8 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2207,23 +2207,24 @@ def foo2(x: Optional[str]) -> None:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/isinstance.pyi]

[case testNoneCheckDoesNotNarrowWhenUsingTypeVars]

# Note: this test (and the following one) are testing checker.conditional_type_map:
# if you set the 'prohibit_none_typevar_overlap' keyword argument to False when calling
# 'is_overlapping_types', the binder will incorrectly infer that 'out' has a type of
# Union[T, None] after the if statement.

[case testNoneCheckDoesNotMakeTypeVarOptional]
from typing import TypeVar

T = TypeVar('T')

def foo(x: T) -> T:
def foo_if(x: T) -> T:
out = None
out = x
if out is None:
pass
return out

def foo_while(x: T) -> T:
out = None
out = x
while out is None:
pass
return out
[builtins fixtures/isinstance.pyi]

[case testNoneCheckDoesNotNarrowWhenUsingTypeVarsNoStrictOptional]
Expand Down
Loading

0 comments on commit 8844183

Please sign in to comment.