Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ==-based narrowing of Optional #18163

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6274,10 +6274,6 @@ def has_no_custom_eq_checks(t: Type) -> bool:
coerce_only_in_literal_context,
)

# Strictly speaking, we should also skip this check if the objects in the expr
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
# assume nobody would actually create a custom objects that considers itself
# equal to None.
if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys()
Expand Down Expand Up @@ -6602,25 +6598,36 @@ def refine_away_none_in_comparison(
For more details about what the different arguments mean, see the
docstring of 'refine_identity_comparison_expression' up above.
"""

non_optional_types = []
for i in chain_indices:
typ = operand_types[i]
if not is_overlapping_none(typ):
non_optional_types.append(typ)

# Make sure we have a mixture of optional and non-optional types.
if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices):
return {}, {}
if_map, else_map = {}, {}

if_map = {}
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)
if not non_optional_types or (len(non_optional_types) != len(chain_indices)):

return if_map, {}
# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
# convenient but is strictly not type-safe):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)

# Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
# so type-safe but less convenient, because e.g. `Optional[A] == None` still results
# in `Optional[A]`):
if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if is_overlapping_none(expr_type):
else_map[operands[i]] = remove_optional(expr_type)

return if_map, else_map

def is_len_of_tuple(self, expr: Expression) -> bool:
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1385,9 +1385,9 @@ val: Optional[A]
if val == None:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "__main__.A"
if val != None:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "__main__.A"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"

Expand Down
Loading