Skip to content

Commit

Permalink
[red-knot] Support assert_type
Browse files Browse the repository at this point in the history
  • Loading branch information
InSyncWithFoo committed Jan 9, 2025
1 parent d0b2bbd commit 41f6284
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# `assert_type`

## Basic

```py
from typing_extensions import assert_type

def _(x: int):
assert_type(x, int) # fine
assert_type(x, str) # error: [type-assertion-failure]
```

## Narrowing

The asserted type is checked against the inferred type, not the declared type.

```toml
[environment]
python-version = "3.10"
```

```py
from typing_extensions import assert_type

def _(x: int | str):
if isinstance(x, int):
reveal_type(x) # revealed: int
assert_type(x, int) # fine
```

## Equivalence

The actual type must match the asserted type precisely.

```py
from typing import Any, Type, Union
from typing_extensions import assert_type

# Subtype does not count
def _(x: bool):
assert_type(x, int) # error: [type-assertion-failure]

def _(a: type[int], b: type[Any]):
assert_type(a, type[Any]) # error: [type-assertion-failure]
assert_type(b, type[int]) # error: [type-assertion-failure]

# The expression constructing the type is not taken into account
def _(a: type[int]):
# TODO: Infer the second argument as a type expression
assert_type(a, Type[int]) # error: [type-assertion-failure]
```

## Gradual types

```py
from typing import Any
from typing_extensions import Literal, assert_type

from knot_extensions import Unknown

# Any and Unknown are considered equivalent
def _(a: Unknown, b: Any):
reveal_type(a) # revealed: Unknown
assert_type(a, Any) # fine

reveal_type(b) # revealed: Any
assert_type(b, Unknown) # fine

def _(a: type[Unknown], b: type[Any]):
# TODO: Should be `type[Unknown]`
reveal_type(a) # revealed: @Todo(unsupported type[X] special form)
reveal_type(b) # revealed: type[Any]

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, type[Unknown]) # error: [type-assertion-failure]
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(b, type[Any]) # error: [type-assertion-failure]
```

## Tuples

Tuple types with the same elements are the same.

```py
from typing_extensions import assert_type

def _(a: tuple[int, str, bytes]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure]

assert_type(a, tuple[int, str]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure]
assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure]

def _(a: tuple[Any, ...]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure]
```

## Unions

Unions with the same elements are the same, regardless of order.

```toml
[environment]
python-version = "3.10"
```

```py
from typing_extensions import assert_type

def _(a: str | int):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, int | str) # error: [type-assertion-failure]
```

## Intersections

Intersections are the same when their positive and negative parts are respectively the same,
regardless of order.

```py
from typing_extensions import assert_type

from knot_extensions import Intersection, Not

class A: ...
class B: ...
class C: ...
class D: ...

def _(a: A):
if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D):
reveal_type(a) # revealed: A & B & ~C & ~D

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, Intersection[B, A, Not[D], Not[C]]) # error: [type-assertion-failure]
```
176 changes: 163 additions & 13 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::hash::Hash;
use std::iter;

use context::InferContext;
use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound};
Expand Down Expand Up @@ -1091,6 +1092,84 @@ impl<'db> Type<'db> {
)
}

/// Returns true if this type and `other` are gradual equivalent.
///
/// > Two gradual types `A` and `B` are equivalent
/// > (that is, the same gradual type, not merely consistent with one another)
/// > if and only if all materializations of `A` are also materializations of `B`,
/// > and all materializations of `B` are also materializations of `A`.
/// >
/// > &mdash; [Summary of type relations]
///
/// This powers the `assert_type()` directive.
///
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let equivalent =
|(first, second): (&Type<'db>, &Type<'db>)| first.is_gradual_equivalent_to(db, *second);

match (self, other) {
(_, _) if self == other => true,

(Type::Any | Type::Unknown, Type::Any | Type::Unknown) => true,

(Type::Instance(instance), Type::SubclassOf(subclass))
| (Type::SubclassOf(subclass), Type::Instance(instance)) => {
let Some(base_class) = subclass.subclass_of().into_class() else {
return false;
};

instance.class.is_known(db, KnownClass::Type)
&& base_class.is_known(db, KnownClass::Object)
}

(Type::SubclassOf(first), Type::SubclassOf(second)) => {
match (first.subclass_of(), second.subclass_of()) {
(first, second) if first == second => true,
(ClassBase::Any | ClassBase::Unknown, ClassBase::Any | ClassBase::Unknown) => {
true
}
_ => false,
}
}

(Type::Tuple(first), Type::Tuple(second)) => {
first.len(db) == second.len(db)
&& iter::zip(first.elements(db), second.elements(db)).all(equivalent)
}

(Type::Union(first), Type::Union(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);

if first_elements.len() != second_elements.len() {
return false;
}

iter::zip(first_elements, second_elements).all(equivalent)
}

(Type::Intersection(first), Type::Intersection(second)) => {
let first_positive = first.positive(db);
let first_negative = first.negative(db);

let second_positive = second.positive(db);
let second_negative = second.negative(db);

if first_positive.len() != second_positive.len()
|| first_negative.len() != second_negative.len()
{
return false;
}

iter::zip(first_positive, second_positive).all(equivalent)
&& iter::zip(first_negative, second_negative).all(equivalent)
}

_ => false,
}
}

/// Return true if this type and `other` have no common elements.
///
/// Note: This function aims to have no false positives, but might return
Expand Down Expand Up @@ -3242,6 +3321,9 @@ pub enum KnownFunction {
/// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check)
NoTypeCheck,

/// `typing(_extensions).assert_type`
AssertType,

/// `knot_extensions.static_assert`
StaticAssert,
/// `knot_extensions.is_equivalent_to`
Expand All @@ -3264,18 +3346,7 @@ impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> {
match self {
Self::ConstraintFunction(f) => Some(f),
Self::RevealType
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::StaticAssert
| Self::IsEquivalentTo
| Self::IsSubtypeOf
| Self::IsAssignableTo
| Self::IsDisjointFrom
| Self::IsFullyStatic
| Self::IsSingleton
| Self::IsSingleValued => None,
_ => None,
}
}

Expand All @@ -3297,6 +3368,7 @@ impl KnownFunction {
"no_type_check" if definition.is_typing_definition(db) => {
Some(KnownFunction::NoTypeCheck)
}
"assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType),
"static_assert" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::StaticAssert)
}
Expand Down Expand Up @@ -3662,7 +3734,8 @@ impl<'db> Class<'db> {
// does not accept the right arguments
CallOutcome::Callable { binding }
| CallOutcome::RevealType { binding, .. }
| CallOutcome::StaticAssertionError { binding, .. } => Ok(binding.return_ty()),
| CallOutcome::StaticAssertionError { binding, .. }
| CallOutcome::AssertType { binding, .. } => Ok(binding.return_ty()),
};

return return_ty_result.map(|ty| ty.to_meta_type(db));
Expand Down Expand Up @@ -4627,6 +4700,83 @@ pub(crate) mod tests {
assert!(!from.into_type(&db).is_fully_static(&db));
}

#[test_case(Ty::Todo, Ty::Todo)]
#[test_case(Ty::Any, Ty::Any)]
#[test_case(Ty::Unknown, Ty::Unknown)]
#[test_case(Ty::Any, Ty::Unknown)]
#[test_case(Ty::Never, Ty::Never)]
#[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)]
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))]
#[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))]
#[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)]
#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfBuiltinClass("object"))]
// TODO: Compare unions/intersections with different orders
// #[test_case(
// Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
// Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
// )]
// #[test_case(
// Ty::Intersection {
// pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")],
// neg: vec![Ty::BuiltinInstance("bytes"), Ty::None]
// },
// Ty::Intersection {
// pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")],
// neg: vec![Ty::None, Ty::BuiltinInstance("bytes")]
// }
// )]
// #[test_case(
// Ty::Intersection {
// pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])],
// neg: vec![Ty::SubclassOfAny]
// },
// Ty::Intersection {
// pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])],
// neg: vec![Ty::SubclassOfUnknown]
// }
// )]
fn is_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(a.is_gradual_equivalent_to(&db, b));
assert!(b.is_gradual_equivalent_to(&db, a));
}

#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfAny)]
#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfUnknown)]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfAny)]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfUnknown)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
fn is_not_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(!a.is_gradual_equivalent_to(&db, b));
assert!(!b.is_gradual_equivalent_to(&db, a));
}

#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))]
Expand Down
Loading

0 comments on commit 41f6284

Please sign in to comment.