Skip to content

Commit

Permalink
[red-knot] Support assert_type
Browse files Browse the repository at this point in the history
  • Loading branch information
InSyncWithFoo committed Jan 3, 2025
1 parent 706d87f commit 580f2a2
Show file tree
Hide file tree
Showing 4 changed files with 366 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import reveal_type

# `assert_type`

## Basic

```py
from typing_extensions import assert_type

def _(x: int):
assert_type(x, int) # fine
assert_type(x, str) # error: [type-assertion-failure]
```

## Narrowing

The asserted type is checked against the inferred type, not the declared type.

```py
from typing_extensions import assert_type

def _(x: int | str):
if isinstance(x, int):
reveal_type(x) # revealed: int
assert_type(x, int) # fine
```

## Equivalence

The actual type must match the asserted type precisely.

```py
from typing import Any, Type, Union
from typing_extensions import assert_type

# Subtype does not count
def _(x: bool):
assert_type(x, int) # error: [type-assertion-failure]

def _(a: type[int], b: type[Any]):
assert_type(a, type[Any]) # error: [type-assertion-failure]
assert_type(b, type[int]) # error: [type-assertion-failure]

# The expression constructing the type is not taken into account
def _(a: type[int]):
# TODO: Infer the second argument as a type expression
assert_type(a, Type[int]) # error: [type-assertion-failure]
```

## Gradual types

```py
from typing_extensions import Literal, assert_type

# Any and Unknown are considered equivalent
def _(a):
reveal_type(a) # revealed: Unknown
assert_type(a, Any) # fine

def _(b: type[Literal]): # TODO: Should be invalid
# TODO: Should be `type[Unknown]`
reveal_type(b) # revealed: @Todo(unsupported type[X] special form)

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(b, type[Any]) # error: [type-assertion-failure]
```

## Tuples

Tuple types with the same elements are the same.

```py
from typing_extensions import assert_type

def _(a: tuple[int, str, bytes]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure]

assert_type(a, tuple[int, str]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure]
assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure]

def _(a: tuple[Any, ...]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure]
```

## Unions

Unions with the same elements are the same, regardless of order.

```py
from typing_extensions import assert_type

def _(a: str | int):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, str | int) # error: [type-assertion-failure]
```

## Intersections

Intersections are the same when their positive and negative parts are respectively the same,
regardless of order.

```py
from typing_extensions import assert_type

class A: ...
class B: ...
class C: ...
class D: ...

def _(a: A):
if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D):
reveal_type(a) # revealed: A & B & ~C & ~D

# TODO: Use Python API to spell intersection type
# assert_type(a, B & A & ~D & ~C)
```
181 changes: 179 additions & 2 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::hash::Hash;
use rustc_hash::FxHasher;
use std::collections::HashSet;
use std::hash::{BuildHasherDefault, Hash};
use std::iter;

use context::InferContext;
use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound};
Expand Down Expand Up @@ -547,6 +550,8 @@ pub enum Type<'db> {
// TODO protocols, callable types, overloads, generics, type vars
}

type OrderedTypeSet<'a, 'db> = HashSet<&'a Type<'db>, BuildHasherDefault<FxHasher>>;

impl<'db> Type<'db> {
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Unknown)
Expand Down Expand Up @@ -1401,6 +1406,89 @@ impl<'db> Type<'db> {
}
}

/// Returns true if this type and `other` are "exactly the same".
///
/// This powers the `assert_type()` directive.
pub(crate) fn is_equals_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let equal = |(first, second): (&Type<'db>, &Type<'db>)| first.is_equals_to(db, *second);

match (self, other) {
(Type::Todo(_), Type::Todo(_)) => false,

(Type::Any | Type::Unknown, Type::Any | Type::Unknown)
| (Type::Never, Type::Never)
| (Type::AlwaysTruthy, Type::AlwaysTruthy)
| (Type::AlwaysFalsy, Type::AlwaysFalsy)
| (Type::LiteralString, Type::LiteralString) => true,

(Type::KnownInstance(first), Type::KnownInstance(second)) => first == second,

(Type::FunctionLiteral(first), Type::FunctionLiteral(second)) => first == second,

(Type::ModuleLiteral(first), Type::ModuleLiteral(second)) => {
first.module(db) == second.module(db)
}

(Type::IntLiteral(first), Type::IntLiteral(second)) => first == second,
(Type::BooleanLiteral(first), Type::BooleanLiteral(second)) => first == second,
(Type::StringLiteral(first), Type::StringLiteral(second)) => {
first.value(db).as_ref() == second.value(db).as_ref()
}
(Type::BytesLiteral(first), Type::BytesLiteral(second)) => {
first.value(db).as_ref() == second.value(db).as_ref()
}
(Type::SliceLiteral(first), Type::SliceLiteral(second)) => {
first.as_tuple(db) == second.as_tuple(db)
}

(Type::ClassLiteral(first), Type::ClassLiteral(second)) => first.class == second.class,
(Type::Instance(first), Type::Instance(second)) => first.class == second.class,

(Type::SubclassOf(first), Type::SubclassOf(second)) => {
match (first.base, second.base) {
(ClassBase::Todo(_), ClassBase::Todo(_)) => false,
(ClassBase::Any | ClassBase::Unknown, ClassBase::Any | ClassBase::Unknown) => {
true
}
(ClassBase::Class(first), ClassBase::Class(second)) => first == second,
_ => false,
}
}

(Type::Tuple(first), Type::Tuple(second)) => {
first.len(db) == second.len(db)
&& iter::zip(first.elements(db), second.elements(db)).all(equal)
}

(Type::Union(first), Type::Union(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);

if first_elements.len() != second_elements.len() {
return false;
}

let first_elements = first_elements.iter().collect::<OrderedTypeSet>();
let second_elements = second_elements.iter().collect::<OrderedTypeSet>();

iter::zip(first_elements, second_elements).all(equal)
}

(Type::Intersection(first), Type::Intersection(second)) => {
let first_positive = first.positive(db).iter().collect::<OrderedTypeSet>();
let first_negative = first.negative(db).iter().collect::<OrderedTypeSet>();

let second_positive = second.positive(db).iter().collect::<OrderedTypeSet>();
let second_negative = second.negative(db).iter().collect::<OrderedTypeSet>();

iter::zip(first_positive, second_positive).all(equal)
&& iter::zip(first_negative, second_negative).all(equal)
}

_ => false,
}
}

/// Return true if there is just a single inhabitant for this type.
///
/// Note: This function aims to have no false positives, but might return `false`
Expand Down Expand Up @@ -1818,6 +1906,21 @@ impl<'db> Type<'db> {
CallOutcome::callable(len_ty.unwrap_or(normal_return_ty))
}

Some(KnownFunction::AssertType) => {
let normal_return_ty = function_type.signature(db).return_ty;

let [actual_type, asserted] = arg_types else {
// TODO: Emit a diagnostic
return CallOutcome::callable(normal_return_ty);
};
let Ok(asserted_type) = asserted.in_type_expression(db) else {
// TODO: Emit a diagnostic
return CallOutcome::callable(normal_return_ty);
};

CallOutcome::asserted(normal_return_ty, *actual_type, asserted_type)
}

_ => CallOutcome::callable(function_type.signature(db).return_ty),
},

Expand Down Expand Up @@ -3083,13 +3186,15 @@ pub enum KnownFunction {

/// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check)
NoTypeCheck,
/// `typing(_extensions).assert_type`
AssertType,
}

impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> {
match self {
Self::ConstraintFunction(f) => Some(f),
Self::RevealType | Self::Len | Self::Final | Self::NoTypeCheck => None,
Self::RevealType | Self::Len | Self::Final | Self::NoTypeCheck | Self::AssertType => None,
}
}

Expand All @@ -3111,6 +3216,7 @@ impl KnownFunction {
"no_type_check" if definition.is_typing_definition(db) => {
Some(KnownFunction::NoTypeCheck)
}
"assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType),
_ => None,
}
}
Expand Down Expand Up @@ -4443,6 +4549,77 @@ pub(crate) mod tests {
assert!(!from.into_type(&db).is_fully_static(&db));
}

#[test_case(Ty::Any, Ty::Any)]
#[test_case(Ty::Unknown, Ty::Unknown)]
#[test_case(Ty::Any, Ty::Unknown)]
#[test_case(Ty::Never, Ty::Never)]
#[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)]
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))]
#[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))]
#[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
#[test_case(
Ty::Intersection {
pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")],
neg: vec![Ty::BuiltinInstance("bytes"), Ty::None]
},
Ty::Intersection {
pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")],
neg: vec![Ty::None, Ty::BuiltinInstance("bytes")]
}
)]
#[test_case(
Ty::Intersection {
pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])],
neg: vec![Ty::SubclassOfAny]
},
Ty::Intersection {
pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])],
neg: vec![Ty::SubclassOfUnknown]
}
)]
fn is_equals_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(a.is_equals_to(&db, b));
assert!(b.is_equals_to(&db, a));
}

#[test_case(Ty::Todo, Ty::Todo)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
fn is_not_equal_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(!a.is_equals_to(&db, b));
assert!(!b.is_equals_to(&db, a));
}

#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))]
Expand Down
Loading

0 comments on commit 580f2a2

Please sign in to comment.