From 7bb9545224b736d730824280e68ab01630254d5c Mon Sep 17 00:00:00 2001 From: InSyncWithFoo Date: Wed, 8 Jan 2025 09:11:00 +0000 Subject: [PATCH] [red-knot] Support `assert_type` --- .../mdtest/directives/assert_type.md | 144 ++++++++++++++ crates/red_knot_python_semantic/src/types.rs | 180 ++++++++++++++++-- .../src/types/call.rs | 40 +++- .../src/types/diagnostic.rs | 23 +++ 4 files changed, 373 insertions(+), 14 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md diff --git a/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md b/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md new file mode 100644 index 0000000000000..54d1f98fe816c --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md @@ -0,0 +1,144 @@ +# `assert_type` + +## Basic + +```py +from typing_extensions import assert_type + +def _(x: int): + assert_type(x, int) # fine + assert_type(x, str) # error: [type-assertion-failure] +``` + +## Narrowing + +The asserted type is checked against the inferred type, not the declared type. + +```toml +[environment] +python-version = "3.10" +``` + +```py +from typing_extensions import assert_type + +def _(x: int | str): + if isinstance(x, int): + reveal_type(x) # revealed: int + assert_type(x, int) # fine +``` + +## Equivalence + +The actual type must match the asserted type precisely. + +```py +from typing import Any, Type, Union +from typing_extensions import assert_type + +# Subtype does not count +def _(x: bool): + assert_type(x, int) # error: [type-assertion-failure] + +def _(a: type[int], b: type[Any]): + assert_type(a, type[Any]) # error: [type-assertion-failure] + assert_type(b, type[int]) # error: [type-assertion-failure] + +# The expression constructing the type is not taken into account +def _(a: type[int]): + # TODO: Infer the second argument as a type expression + assert_type(a, Type[int]) # error: [type-assertion-failure] +``` + +## Gradual types + +```py +from typing import Any +from typing_extensions import Literal, assert_type + +from knot_extensions import Unknown + +# Any and Unknown are considered equivalent +def _(a: Unknown, b: Any): + reveal_type(a) # revealed: Unknown + assert_type(a, Any) # fine + + reveal_type(b) # revealed: Any + assert_type(b, Unknown) # fine + +def _(a: type[Unknown], b: type[Any]): + # TODO: Should be `type[Unknown]` + reveal_type(a) # revealed: @Todo(unsupported type[X] special form) + reveal_type(b) # revealed: type[Any] + + # TODO: Infer the second argument as a type expression + # Should be fine + assert_type(a, type[Unknown]) # error: [type-assertion-failure] + # TODO: Infer the second argument as a type expression + # Should be fine + assert_type(b, type[Any]) # error: [type-assertion-failure] +``` + +## Tuples + +Tuple types with the same elements are the same. + +```py +from typing_extensions import assert_type + +def _(a: tuple[int, str, bytes]): + # TODO: Infer the second argument as a type expression + # Should be fine + assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure] + + assert_type(a, tuple[int, str]) # error: [type-assertion-failure] + assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure] + assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure] + +def _(a: tuple[Any, ...]): + # TODO: Infer the second argument as a type expression + # Should be fine + assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure] +``` + +## Unions + +Unions with the same elements are the same, regardless of order. + +```toml +[environment] +python-version = "3.10" +``` + +```py +from typing_extensions import assert_type + +def _(a: str | int): + # TODO: Infer the second argument as a type expression + # Should be fine + assert_type(a, int | str) # error: [type-assertion-failure] +``` + +## Intersections + +Intersections are the same when their positive and negative parts are respectively the same, +regardless of order. + +```py +from typing_extensions import assert_type + +from knot_extensions import Intersection, Not + +class A: ... +class B: ... +class C: ... +class D: ... + +def _(a: A): + if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D): + reveal_type(a) # revealed: A & B & ~C & ~D + + # TODO: Infer the second argument as a type expression + # Should be fine + assert_type(a, Intersection[B, A, Not[D], Not[C]]) # error: [type-assertion-failure] +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index d259a2d641a28..bfdd4b2fd2517 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,4 +1,7 @@ -use std::hash::Hash; +use rustc_hash::FxHasher; +use std::collections::HashSet; +use std::hash::{BuildHasherDefault, Hash}; +use std::iter; use context::InferContext; use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound}; @@ -553,6 +556,8 @@ pub enum Type<'db> { // TODO protocols, callable types, overloads, generics, type vars } +type OrderedTypeSet<'a, 'db> = HashSet<&'a Type<'db>, BuildHasherDefault>; + impl<'db> Type<'db> { pub const fn is_unknown(&self) -> bool { matches!(self, Type::Unknown) @@ -1380,6 +1385,78 @@ impl<'db> Type<'db> { } } + /// Returns true if this type and `other` are gradual equivalent. + /// + /// > Two gradual types `A` and `B` are equivalent + /// > (that is, the same gradual type, not merely consistent with one another) + /// > if and only if all materializations of `A` are also materializations of `B`, + /// > and all materializations of `B` are also materializations of `A`. + /// > + /// > — [Summary of type relations] + /// + /// Note: `Todo != Todo`. + /// + /// This powers the `assert_type()` directive. + /// + /// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations + pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { + let equivalent = + |(first, second): (&Type<'db>, &Type<'db>)| first.is_gradual_equivalent_to(db, *second); + + if self == other { + return true; + } + + match (self, other) { + (Type::Todo(_), Type::Todo(_)) => false, + + (Type::Any | Type::Unknown, Type::Any | Type::Unknown) => true, + + (Type::SubclassOf(first), Type::SubclassOf(second)) => { + match (first.subclass_of(), second.subclass_of()) { + (ClassBase::Todo(_), ClassBase::Todo(_)) => false, + (ClassBase::Any | ClassBase::Unknown, ClassBase::Any | ClassBase::Unknown) => { + true + } + (ClassBase::Class(first), ClassBase::Class(second)) => first == second, + _ => false, + } + } + + (Type::Tuple(first), Type::Tuple(second)) => { + first.len(db) == second.len(db) + && iter::zip(first.elements(db), second.elements(db)).all(equivalent) + } + + (Type::Union(first), Type::Union(second)) => { + let first_elements = first.elements(db); + let second_elements = second.elements(db); + + if first_elements.len() != second_elements.len() { + return false; + } + + let first_elements = first_elements.iter().collect::(); + let second_elements = second_elements.iter().collect::(); + + iter::zip(first_elements, second_elements).all(equivalent) + } + + (Type::Intersection(first), Type::Intersection(second)) => { + let first_positive = first.positive(db).iter().collect::(); + let first_negative = first.negative(db).iter().collect::(); + + let second_positive = second.positive(db).iter().collect::(); + let second_negative = second.negative(db).iter().collect::(); + + iter::zip(first_positive, second_positive).all(equivalent) + && iter::zip(first_negative, second_negative).all(equivalent) + } + + _ => false, + } + } + /// Return true if there is just a single inhabitant for this type. /// /// Note: This function aims to have no false positives, but might return `false` @@ -1862,6 +1939,19 @@ impl<'db> Type<'db> { CallOutcome::callable(binding) } + Some(KnownFunction::AssertType) => { + let [_actual_type, asserted] = binding.parameter_tys() else { + return CallOutcome::callable(binding); + }; + + // TODO: Infer this as a type expression directly + let Ok(asserted_type) = asserted.in_type_expression(db) else { + return CallOutcome::callable(binding); + }; + + CallOutcome::asserted(binding, asserted_type) + } + _ => CallOutcome::callable(binding), } } @@ -3178,6 +3268,9 @@ pub enum KnownFunction { /// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check) NoTypeCheck, + /// `typing(_extensions).assert_type` + AssertType, + /// `knot_extensions.static_assert` StaticAssert, /// `knot_extensions.is_equivalent_to` @@ -3200,18 +3293,7 @@ impl KnownFunction { pub fn constraint_function(self) -> Option { match self { Self::ConstraintFunction(f) => Some(f), - Self::RevealType - | Self::Len - | Self::Final - | Self::NoTypeCheck - | Self::StaticAssert - | Self::IsEquivalentTo - | Self::IsSubtypeOf - | Self::IsAssignableTo - | Self::IsDisjointFrom - | Self::IsFullyStatic - | Self::IsSingleton - | Self::IsSingleValued => None, + _ => None, } } @@ -3233,6 +3315,7 @@ impl KnownFunction { "no_type_check" if definition.is_typing_definition(db) => { Some(KnownFunction::NoTypeCheck) } + "assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType), "static_assert" if definition.is_knot_extensions_definition(db) => { Some(KnownFunction::StaticAssert) } @@ -4585,6 +4668,77 @@ pub(crate) mod tests { assert!(!from.into_type(&db).is_fully_static(&db)); } + #[test_case(Ty::Any, Ty::Any)] + #[test_case(Ty::Unknown, Ty::Unknown)] + #[test_case(Ty::Any, Ty::Unknown)] + #[test_case(Ty::Never, Ty::Never)] + #[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)] + #[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)] + #[test_case(Ty::LiteralString, Ty::LiteralString)] + #[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))] + #[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))] + #[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))] + #[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))] + #[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)] + #[test_case( + Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]) + )] + #[test_case( + Ty::Intersection { + pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")], + neg: vec![Ty::BuiltinInstance("bytes"), Ty::None] + }, + Ty::Intersection { + pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")], + neg: vec![Ty::None, Ty::BuiltinInstance("bytes")] + } + )] + #[test_case( + Ty::Intersection { + pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])], + neg: vec![Ty::SubclassOfAny] + }, + Ty::Intersection { + pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])], + neg: vec![Ty::SubclassOfUnknown] + } + )] + fn is_gradual_equivalent_to(a: Ty, b: Ty) { + let db = setup_db(); + let a = a.into_type(&db); + let b = b.into_type(&db); + + assert!(a.is_gradual_equivalent_to(&db, b)); + assert!(b.is_gradual_equivalent_to(&db, a)); + } + + #[test_case(Ty::Todo, Ty::Todo)] + #[test_case( + Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")]) + )] + #[test_case( + Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]), + Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")]) + )] + #[test_case( + Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]) + )] + #[test_case( + Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]) + )] + fn is_not_gradual_equivalent_to(a: Ty, b: Ty) { + let db = setup_db(); + let a = a.into_type(&db); + let b = b.into_type(&db); + + assert!(!a.is_gradual_equivalent_to(&db, b)); + assert!(!b.is_gradual_equivalent_to(&db, a)); + } + #[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")] #[test_case(Ty::IntLiteral(-1))] #[test_case(Ty::StringLiteral("foo"))] diff --git a/crates/red_knot_python_semantic/src/types/call.rs b/crates/red_knot_python_semantic/src/types/call.rs index 5b5e4d978d76a..6995f14d290d1 100644 --- a/crates/red_knot_python_semantic/src/types/call.rs +++ b/crates/red_knot_python_semantic/src/types/call.rs @@ -1,5 +1,5 @@ use super::context::InferContext; -use super::diagnostic::CALL_NON_CALLABLE; +use super::diagnostic::{CALL_NON_CALLABLE, TYPE_ASSERTION_FAILURE}; use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder}; use crate::types::diagnostic::STATIC_ASSERT_ERROR; use crate::types::Truthiness; @@ -39,6 +39,10 @@ pub(super) enum CallOutcome<'db> { truthiness: Truthiness, message: Option<&'db str>, }, + AssertType { + binding: CallBinding<'db>, + asserted_ty: Type<'db>, + }, } impl<'db> CallOutcome<'db> { @@ -71,6 +75,14 @@ impl<'db> CallOutcome<'db> { } } + /// Create a new `CallOutcome::AssertType` with given revealed and return types. + pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> { + CallOutcome::AssertType { + binding, + asserted_ty, + } + } + /// Get the return type of the call, or `None` if not callable. pub(super) fn return_ty(&self, db: &'db dyn Db) -> Option> { match self { @@ -98,6 +110,10 @@ impl<'db> CallOutcome<'db> { .map(UnionBuilder::build), Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_ty(db), Self::StaticAssertionError { .. } => Some(Type::none(db)), + Self::AssertType { + binding, + asserted_ty: _, + } => Some(binding.return_ty()), } } @@ -341,6 +357,28 @@ impl<'db> CallOutcome<'db> { Ok(Type::Unknown) } + CallOutcome::AssertType { + binding, + asserted_ty, + } => { + let [actual_ty, _asserted] = binding.parameter_tys() else { + return Ok(binding.return_ty()); + }; + + if !actual_ty.is_gradual_equivalent_to(context.db(), *asserted_ty) { + context.report_lint( + &TYPE_ASSERTION_FAILURE, + node, + format_args!( + "Actual type `{}` is not the same as asserted type `{}`", + actual_ty.display(context.db()), + asserted_ty.display(context.db()), + ), + ); + } + + Ok(binding.return_ty()) + } } } } diff --git a/crates/red_knot_python_semantic/src/types/diagnostic.rs b/crates/red_knot_python_semantic/src/types/diagnostic.rs index 7fe097ce4eadd..c89963bcbeb55 100644 --- a/crates/red_knot_python_semantic/src/types/diagnostic.rs +++ b/crates/red_knot_python_semantic/src/types/diagnostic.rs @@ -48,6 +48,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) { registry.register_lint(&POSSIBLY_UNBOUND_IMPORT); registry.register_lint(&POSSIBLY_UNRESOLVED_REFERENCE); registry.register_lint(&SUBCLASS_OF_FINAL_CLASS); + registry.register_lint(&TYPE_ASSERTION_FAILURE); registry.register_lint(&TOO_MANY_POSITIONAL_ARGUMENTS); registry.register_lint(&UNDEFINED_REVEAL); registry.register_lint(&UNKNOWN_ARGUMENT); @@ -546,6 +547,28 @@ declare_lint! { } } +declare_lint! { + /// ## What it does + /// Checks for `assert_type()` calls where the actual type + /// is not the same as the asserted type. + /// + /// ## Why is this bad? + /// `assert_type()` allows confirming the inferred type of a certain value. + /// + /// ## Example + /// + /// ```python + /// def _(x: int): + /// assert_type(x, int) # fine + /// assert_type(x, str) # error: Actual type does not match asserted type + /// ``` + pub(crate) static TYPE_ASSERTION_FAILURE = { + summary: "detects failed type assertions", + status: LintStatus::preview("1.0.0"), + default_level: Level::Error, + } +} + declare_lint! { /// ## What it does /// Checks for calls that pass more positional arguments than the callable can accept.