Skip to content

Commit

Permalink
[red-knot] Support assert_type
Browse files Browse the repository at this point in the history
  • Loading branch information
InSyncWithFoo committed Jan 8, 2025
1 parent 339167d commit bf4957f
Show file tree
Hide file tree
Showing 4 changed files with 383 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# `assert_type`

## Basic

```py
from typing_extensions import assert_type

def _(x: int):
assert_type(x, int) # fine
assert_type(x, str) # error: [type-assertion-failure]
```

## Narrowing

The asserted type is checked against the inferred type, not the declared type.

```toml
[environment]
python-version = "3.10"
```

```py
from typing_extensions import assert_type

def _(x: int | str):
if isinstance(x, int):
reveal_type(x) # revealed: int
assert_type(x, int) # fine
```

## Equivalence

The actual type must match the asserted type precisely.

```py
from typing import Any, Type, Union
from typing_extensions import assert_type

# Subtype does not count
def _(x: bool):
assert_type(x, int) # error: [type-assertion-failure]

def _(a: type[int], b: type[Any]):
assert_type(a, type[Any]) # error: [type-assertion-failure]
assert_type(b, type[int]) # error: [type-assertion-failure]

# The expression constructing the type is not taken into account
def _(a: type[int]):
# TODO: Infer the second argument as a type expression
assert_type(a, Type[int]) # error: [type-assertion-failure]
```

## Gradual types

```py
from typing import Any
from typing_extensions import Literal, assert_type

from knot_extensions import Unknown

# Any and Unknown are considered equivalent
def _(a: Unknown, b: Any):
reveal_type(a) # revealed: Unknown
assert_type(a, Any) # fine

reveal_type(b) # revealed: Any
assert_type(b, Unknown) # fine

def _(a: type[Unknown], b: type[Any]):
# TODO: Should be `type[Unknown]`
reveal_type(a) # revealed: @Todo(unsupported type[X] special form)
reveal_type(b) # revealed: type[Any]

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, type[Unknown]) # error: [type-assertion-failure]
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(b, type[Any]) # error: [type-assertion-failure]
```

## Tuples

Tuple types with the same elements are the same.

```py
from typing_extensions import assert_type

def _(a: tuple[int, str, bytes]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure]

assert_type(a, tuple[int, str]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure]
assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure]

def _(a: tuple[Any, ...]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure]
```

## Unions

Unions with the same elements are the same, regardless of order.

```toml
[environment]
python-version = "3.10"
```

```py
from typing_extensions import assert_type

def _(a: str | int):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, int | str) # error: [type-assertion-failure]
```

## Intersections

Intersections are the same when their positive and negative parts are respectively the same,
regardless of order.

```py
from typing_extensions import assert_type

from knot_extensions import Intersection, Not

class A: ...
class B: ...
class C: ...
class D: ...

def _(a: A):
if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D):
reveal_type(a) # revealed: A & B & ~C & ~D

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, Intersection[B, A, Not[D], Not[C]]) # error: [type-assertion-failure]
```
190 changes: 177 additions & 13 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::hash::Hash;
use rustc_hash::FxHasher;
use std::collections::HashSet;
use std::hash::{BuildHasherDefault, Hash};
use std::iter;

use context::InferContext;
use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound};
Expand Down Expand Up @@ -555,6 +558,8 @@ pub enum Type<'db> {
// TODO protocols, callable types, overloads, generics, type vars
}

type OrderedTypeSet<'a, 'db> = HashSet<&'a Type<'db>, BuildHasherDefault<FxHasher>>;

impl<'db> Type<'db> {
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Unknown)
Expand Down Expand Up @@ -1431,6 +1436,88 @@ impl<'db> Type<'db> {
}
}

/// Returns true if this type and `other` are gradual equivalent.
///
/// > Two gradual types `A` and `B` are equivalent
/// > (that is, the same gradual type, not merely consistent with one another)
/// > if and only if all materializations of `A` are also materializations of `B`,
/// > and all materializations of `B` are also materializations of `A`.
/// >
/// > &mdash; [Summary of type relations]
///
/// Note: `Todo != Todo`.
///
/// This powers the `assert_type()` directive.
///
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let equivalent =
|(first, second): (&Type<'db>, &Type<'db>)| first.is_gradual_equivalent_to(db, *second);

match (self, other) {
(Type::Todo(_), Type::Todo(_)) => false,

(_, _) if self == other => true,

(Type::Any | Type::Unknown, Type::Any | Type::Unknown) => true,

(Type::SubclassOf(first), Type::SubclassOf(second)) => {
match (first.subclass_of(), second.subclass_of()) {
(ClassBase::Todo(_), ClassBase::Todo(_)) => false,
(ClassBase::Any | ClassBase::Unknown, ClassBase::Any | ClassBase::Unknown) => {
true
}
(ClassBase::Class(first), ClassBase::Class(second)) => first == second,
_ => false,
}
}

(Type::Tuple(first), Type::Tuple(second)) => {
first.len(db) == second.len(db)
&& iter::zip(first.elements(db), second.elements(db)).all(equivalent)
}

(Type::Union(first), Type::Union(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);

if first_elements.len() != second_elements.len() {
return false;
}

let first_elements = first_elements.iter().collect::<OrderedTypeSet>();
let second_elements = second_elements.iter().collect::<OrderedTypeSet>();

iter::zip(first_elements, second_elements).all(equivalent)
}

(Type::Intersection(first), Type::Intersection(second)) => {
let first_positive = first.positive(db);
let first_negative = first.negative(db);

let second_positive = second.positive(db);
let second_negative = second.negative(db);

if first_positive.len() != second_positive.len()
|| first_negative.len() != second_negative.len()
{
return false;
}

let first_positive = first_positive.iter().collect::<OrderedTypeSet>();
let first_negative = first_negative.iter().collect::<OrderedTypeSet>();

let second_positive = second_positive.iter().collect::<OrderedTypeSet>();
let second_negative = second_negative.iter().collect::<OrderedTypeSet>();

iter::zip(first_positive, second_positive).all(equivalent)
&& iter::zip(first_negative, second_negative).all(equivalent)
}

_ => false,
}
}

/// Return true if there is just a single inhabitant for this type.
///
/// Note: This function aims to have no false positives, but might return `false`
Expand Down Expand Up @@ -1923,6 +2010,19 @@ impl<'db> Type<'db> {
CallOutcome::callable(binding)
}

Some(KnownFunction::AssertType) => {
let [_actual_type, asserted] = binding.parameter_tys() else {
return CallOutcome::callable(binding);
};

// TODO: Infer this as a type expression directly
let Ok(asserted_type) = asserted.in_type_expression(db) else {
return CallOutcome::callable(binding);
};

CallOutcome::asserted(binding, asserted_type)
}

_ => CallOutcome::callable(binding),
}
}
Expand Down Expand Up @@ -3239,6 +3339,9 @@ pub enum KnownFunction {
/// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check)
NoTypeCheck,

/// `typing(_extensions).assert_type`
AssertType,

/// `knot_extensions.static_assert`
StaticAssert,
/// `knot_extensions.is_equivalent_to`
Expand All @@ -3261,18 +3364,7 @@ impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> {
match self {
Self::ConstraintFunction(f) => Some(f),
Self::RevealType
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::StaticAssert
| Self::IsEquivalentTo
| Self::IsSubtypeOf
| Self::IsAssignableTo
| Self::IsDisjointFrom
| Self::IsFullyStatic
| Self::IsSingleton
| Self::IsSingleValued => None,
_ => None,
}
}

Expand All @@ -3294,6 +3386,7 @@ impl KnownFunction {
"no_type_check" if definition.is_typing_definition(db) => {
Some(KnownFunction::NoTypeCheck)
}
"assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType),
"static_assert" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::StaticAssert)
}
Expand Down Expand Up @@ -4648,6 +4741,77 @@ pub(crate) mod tests {
assert!(!from.into_type(&db).is_fully_static(&db));
}

#[test_case(Ty::Any, Ty::Any)]
#[test_case(Ty::Unknown, Ty::Unknown)]
#[test_case(Ty::Any, Ty::Unknown)]
#[test_case(Ty::Never, Ty::Never)]
#[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)]
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))]
#[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))]
#[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
#[test_case(
Ty::Intersection {
pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")],
neg: vec![Ty::BuiltinInstance("bytes"), Ty::None]
},
Ty::Intersection {
pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")],
neg: vec![Ty::None, Ty::BuiltinInstance("bytes")]
}
)]
#[test_case(
Ty::Intersection {
pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])],
neg: vec![Ty::SubclassOfAny]
},
Ty::Intersection {
pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])],
neg: vec![Ty::SubclassOfUnknown]
}
)]
fn is_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(a.is_gradual_equivalent_to(&db, b));
assert!(b.is_gradual_equivalent_to(&db, a));
}

#[test_case(Ty::Todo, Ty::Todo)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
fn is_not_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(!a.is_gradual_equivalent_to(&db, b));
assert!(!b.is_gradual_equivalent_to(&db, a));
}

#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))]
Expand Down
Loading

0 comments on commit bf4957f

Please sign in to comment.