Skip to content

Commit

Permalink
Support ==-based narrowing of Optional
Browse files Browse the repository at this point in the history
Closes #18135

This change implements the third approach mentioned in #18135, which is stricter than similar narrowings, as clarified by the new/modified code comments. Personally, I prefer this more stringent way but could also switch this PR to approach two if there is a consent that convenience is more important than type safety here.
  • Loading branch information
tyralla committed Nov 18, 2024
1 parent 8ef2197 commit 1086dcb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
37 changes: 22 additions & 15 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6274,10 +6274,6 @@ def has_no_custom_eq_checks(t: Type) -> bool:
coerce_only_in_literal_context,
)

# Strictly speaking, we should also skip this check if the objects in the expr
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
# assume nobody would actually create a custom objects that considers itself
# equal to None.
if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys()
Expand Down Expand Up @@ -6602,25 +6598,36 @@ def refine_away_none_in_comparison(
For more details about what the different arguments mean, see the
docstring of 'refine_identity_comparison_expression' up above.
"""

non_optional_types = []
for i in chain_indices:
typ = operand_types[i]
if not is_overlapping_none(typ):
non_optional_types.append(typ)

# Make sure we have a mixture of optional and non-optional types.
if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices):
return {}, {}
if_map, else_map = {}, {}

if_map = {}
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)
if not non_optional_types or (len(non_optional_types) != len(chain_indices)):

return if_map, {}
# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
# convenient but is strictly not type-safe):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)

# Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
# so type-safe but less convenient, because e.g. `Optional[A] == None` still results
# in `Optional[A]`):
if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if is_overlapping_none(expr_type):
else_map[operands[i]] = remove_optional(expr_type)

return if_map, else_map

def is_len_of_tuple(self, expr: Expression) -> bool:
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1385,9 +1385,9 @@ val: Optional[A]
if val == None:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "__main__.A"
if val != None:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "__main__.A"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"

Expand Down

0 comments on commit 1086dcb

Please sign in to comment.