diff --git a/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md b/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md new file mode 100644 index 00000000000000..fe511cb71f637c --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md @@ -0,0 +1,145 @@ +# `assert_type` + +## Basic + +```py +from typing_extensions import assert_type + +def _(x: int): + assert_type(x, int) # fine + assert_type(x, str) # error: [type-assertion-failure] +``` + +## Narrowing + +The asserted type is checked against the inferred type, not the declared type. + +```toml +[environment] +python-version = "3.10" +``` + +```py +from typing_extensions import assert_type + +def _(x: int | str): + if isinstance(x, int): + reveal_type(x) # revealed: int + assert_type(x, int) # fine +``` + +## Equivalence + +The actual type must match the asserted type precisely. + +```py +from typing import Any, Type, Union +from typing_extensions import assert_type + +# Subtype does not count +def _(x: bool): + assert_type(x, int) # error: [type-assertion-failure] + +def _(a: type[int], b: type[Any]): + assert_type(a, type[Any]) # error: [type-assertion-failure] + assert_type(b, type[int]) # error: [type-assertion-failure] + +# The expression constructing the type is not taken into account +def _(a: type[int]): + assert_type(a, Type[int]) # fine +``` + +## Gradual types + +```py +from typing import Any +from typing_extensions import Literal, assert_type + +from knot_extensions import Unknown + +# Any and Unknown are considered equivalent +def _(a: Unknown, b: Any): + reveal_type(a) # revealed: Unknown + assert_type(a, Any) # fine + + reveal_type(b) # revealed: Any + assert_type(b, Unknown) # fine + +def _(a: type[Unknown], b: type[Any]): + # TODO: Should be `type[Unknown]` + reveal_type(a) # revealed: @Todo(unsupported type[X] special form) + # TODO: Should be fine + assert_type(a, type[Any]) # error: [type-assertion-failure] + + reveal_type(b) # revealed: type[Any] + # TODO: Should be fine + assert_type(b, type[Unknown]) # error: [type-assertion-failure] +``` + +## Tuples + +Tuple types with the same elements are the same. + +```py +from typing_extensions import assert_type + +from knot_extensions import Unknown + +def _(a: tuple[int, str, bytes]): + assert_type(a, tuple[int, str, bytes]) # fine + + assert_type(a, tuple[int, str]) # error: [type-assertion-failure] + assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure] + assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure] + +def _(a: tuple[Any, ...], b: tuple[Unknown, ...]): + assert_type(a, tuple[Any, ...]) # fine + assert_type(a, tuple[Unknown, ...]) # fine + + assert_type(b, tuple[Unknown, ...]) # fine + assert_type(b, tuple[Any, ...]) # fine +``` + +## Unions + +Unions with the same elements are the same, regardless of order. + +```toml +[environment] +python-version = "3.10" +``` + +```py +from typing_extensions import assert_type + +def _(a: str | int): + assert_type(a, str | int) # fine + + # TODO: Order-independent union handling in type equivalence + assert_type(a, int | str) # error: [type-assertion-failure] +``` + +## Intersections + +Intersections are the same when their positive and negative parts are respectively the same, +regardless of order. + +```py +from typing_extensions import assert_type + +from knot_extensions import Intersection, Not + +class A: ... +class B: ... +class C: ... +class D: ... + +def _(a: A): + if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D): + reveal_type(a) # revealed: A & B & ~C & ~D + + assert_type(a, Intersection[A, B, Not[C], Not[D]]) # fine + + # TODO: Order-independent intersection handling in type equivalence + assert_type(a, Intersection[B, A, Not[D], Not[C]]) # error: [type-assertion-failure] +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 5a2c8e895562d3..9e38e7cfae9f92 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,4 +1,5 @@ use std::hash::Hash; +use std::iter; use context::InferContext; use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound}; @@ -1095,6 +1096,87 @@ impl<'db> Type<'db> { ) } + /// Returns true if this type and `other` are gradual equivalent. + /// + /// > Two gradual types `A` and `B` are equivalent + /// > (that is, the same gradual type, not merely consistent with one another) + /// > if and only if all materializations of `A` are also materializations of `B`, + /// > and all materializations of `B` are also materializations of `A`. + /// > + /// > — [Summary of type relations] + /// + /// This powers the `assert_type()` directive. + /// + /// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations + pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { + let equivalent = + |(first, second): (&Type<'db>, &Type<'db>)| first.is_gradual_equivalent_to(db, *second); + + match (self, other) { + (_, _) if self == other => true, + + (Type::Dynamic(_), Type::Dynamic(_)) => true, + + (Type::Instance(instance), Type::SubclassOf(subclass)) + | (Type::SubclassOf(subclass), Type::Instance(instance)) => { + let Some(base_class) = subclass.subclass_of().into_class() else { + return false; + }; + + instance.class.is_known(db, KnownClass::Type) + && base_class.is_known(db, KnownClass::Object) + } + + (Type::SubclassOf(first), Type::SubclassOf(second)) => { + match (first.subclass_of(), second.subclass_of()) { + (first, second) if first == second => true, + (ClassBase::Dynamic(_), ClassBase::Dynamic(_)) => true, + _ => false, + } + } + + (Type::Tuple(first), Type::Tuple(second)) => { + let first_elements = first.elements(db); + let second_elements = second.elements(db); + + first_elements.len() == second_elements.len() + && iter::zip(first_elements, second_elements).all(equivalent) + } + + // TODO: Handle equivalent unions with items in different order + (Type::Union(first), Type::Union(second)) => { + let first_elements = first.elements(db); + let second_elements = second.elements(db); + + if first_elements.len() != second_elements.len() { + return false; + } + + iter::zip(first_elements, second_elements).all(equivalent) + } + + // TODO: Handle equivalent intersections with items in different order + (Type::Intersection(first), Type::Intersection(second)) => { + let first_positive = first.positive(db); + let first_negative = first.negative(db); + + let second_positive = second.positive(db); + let second_negative = second.negative(db); + + if first_positive.len() != second_positive.len() + || first_negative.len() != second_negative.len() + { + return false; + } + + iter::zip(first_positive, second_positive).all(equivalent) + && iter::zip(first_negative, second_negative).all(equivalent) + } + + _ => false, + } + } + /// Return true if this type and `other` have no common elements. /// /// Note: This function aims to have no false positives, but might return @@ -1924,6 +2006,14 @@ impl<'db> Type<'db> { CallOutcome::callable(binding) } + Some(KnownFunction::AssertType) => { + let Some((_, asserted_ty)) = binding.two_parameter_tys() else { + return CallOutcome::callable(binding); + }; + + CallOutcome::asserted(binding, asserted_ty) + } + _ => CallOutcome::callable(binding), } } @@ -3261,6 +3351,9 @@ pub enum KnownFunction { /// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check) NoTypeCheck, + /// `typing(_extensions).assert_type` + AssertType, + /// `knot_extensions.static_assert` StaticAssert, /// `knot_extensions.is_equivalent_to` @@ -3283,18 +3376,7 @@ impl KnownFunction { pub fn constraint_function(self) -> Option { match self { Self::ConstraintFunction(f) => Some(f), - Self::RevealType - | Self::Len - | Self::Final - | Self::NoTypeCheck - | Self::StaticAssert - | Self::IsEquivalentTo - | Self::IsSubtypeOf - | Self::IsAssignableTo - | Self::IsDisjointFrom - | Self::IsFullyStatic - | Self::IsSingleton - | Self::IsSingleValued => None, + _ => None, } } @@ -3316,6 +3398,7 @@ impl KnownFunction { "no_type_check" if definition.is_typing_definition(db) => { Some(KnownFunction::NoTypeCheck) } + "assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType), "static_assert" if definition.is_knot_extensions_definition(db) => { Some(KnownFunction::StaticAssert) } @@ -3345,20 +3428,34 @@ impl KnownFunction { } } - /// Whether or not a particular function takes type expression as arguments, i.e. should - /// the argument of a call like `f(int)` be interpreted as the type int (true) or as the - /// type of the expression `int`, i.e. `Literal[int]` (false). - const fn takes_type_expression_arguments(self) -> bool { - matches!( - self, - KnownFunction::IsEquivalentTo - | KnownFunction::IsSubtypeOf - | KnownFunction::IsAssignableTo - | KnownFunction::IsDisjointFrom - | KnownFunction::IsFullyStatic - | KnownFunction::IsSingleton - | KnownFunction::IsSingleValued - ) + /// Returns a `u32` bitmask specifying whether or not + /// arguments given to a particular function + /// should be interpreted as type expressions or value expressions. + /// + /// The argument is treated as a type expression + /// when the corresponding bit is `1`. + /// The least-significant (right-most) bit corresponds to + /// the argument at the index 0 and so on. + /// + /// For example, `assert_type()` has the bitmask value of `0b10`. + /// This means the second argument is a type expression and the first a value expression. + const fn takes_type_expression_arguments(self) -> u32 { + const ALL_VALUES: u32 = 0b0; + const SINGLE_TYPE: u32 = 0b1; + const TYPE_TYPE: u32 = 0b11; + const VALUE_TYPE: u32 = 0b10; + + match self { + KnownFunction::IsEquivalentTo => TYPE_TYPE, + KnownFunction::IsSubtypeOf => TYPE_TYPE, + KnownFunction::IsAssignableTo => TYPE_TYPE, + KnownFunction::IsDisjointFrom => TYPE_TYPE, + KnownFunction::IsFullyStatic => SINGLE_TYPE, + KnownFunction::IsSingleton => SINGLE_TYPE, + KnownFunction::IsSingleValued => SINGLE_TYPE, + KnownFunction::AssertType => VALUE_TYPE, + _ => ALL_VALUES, + } } } @@ -3681,7 +3778,8 @@ impl<'db> Class<'db> { // does not accept the right arguments CallOutcome::Callable { binding } | CallOutcome::RevealType { binding, .. } - | CallOutcome::StaticAssertionError { binding, .. } => Ok(binding.return_ty()), + | CallOutcome::StaticAssertionError { binding, .. } + | CallOutcome::AssertType { binding, .. } => Ok(binding.return_ty()), }; return return_ty_result.map(|ty| ty.to_meta_type(db)); @@ -4644,6 +4742,82 @@ pub(crate) mod tests { assert!(!from.into_type(&db).is_fully_static(&db)); } + #[test_case(Ty::Todo, Ty::Todo)] + #[test_case(Ty::Any, Ty::Any)] + #[test_case(Ty::Unknown, Ty::Unknown)] + #[test_case(Ty::Any, Ty::Unknown)] + #[test_case(Ty::Todo, Ty::Unknown)] + #[test_case(Ty::Todo, Ty::Any)] + #[test_case(Ty::Never, Ty::Never)] + #[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)] + #[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)] + #[test_case(Ty::LiteralString, Ty::LiteralString)] + #[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))] + #[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))] + #[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))] + #[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))] + #[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfBuiltinClass("object"))] + // TODO: Compare unions/intersections with different orders + // #[test_case( + // Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + // Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]) + // )] + // #[test_case( + // Ty::Intersection { + // pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")], + // neg: vec![Ty::BuiltinInstance("bytes"), Ty::None] + // }, + // Ty::Intersection { + // pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")], + // neg: vec![Ty::None, Ty::BuiltinInstance("bytes")] + // } + // )] + // #[test_case( + // Ty::Intersection { + // pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])], + // neg: vec![Ty::SubclassOfAny] + // }, + // Ty::Intersection { + // pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])], + // neg: vec![Ty::SubclassOfUnknown] + // } + // )] + fn is_gradual_equivalent_to(a: Ty, b: Ty) { + let db = setup_db(); + let a = a.into_type(&db); + let b = b.into_type(&db); + + assert!(a.is_gradual_equivalent_to(&db, b)); + assert!(b.is_gradual_equivalent_to(&db, a)); + } + + #[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfAny)] + #[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfAny)] + #[test_case( + Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")]) + )] + #[test_case( + Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]), + Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")]) + )] + #[test_case( + Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]) + )] + #[test_case( + Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), + Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]) + )] + fn is_not_gradual_equivalent_to(a: Ty, b: Ty) { + let db = setup_db(); + let a = a.into_type(&db); + let b = b.into_type(&db); + + assert!(!a.is_gradual_equivalent_to(&db, b)); + assert!(!b.is_gradual_equivalent_to(&db, a)); + } + #[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")] #[test_case(Ty::IntLiteral(-1))] #[test_case(Ty::StringLiteral("foo"))] diff --git a/crates/red_knot_python_semantic/src/types/call.rs b/crates/red_knot_python_semantic/src/types/call.rs index bffaa3a81c772c..5b8558131f773e 100644 --- a/crates/red_knot_python_semantic/src/types/call.rs +++ b/crates/red_knot_python_semantic/src/types/call.rs @@ -1,5 +1,5 @@ use super::context::InferContext; -use super::diagnostic::CALL_NON_CALLABLE; +use super::diagnostic::{CALL_NON_CALLABLE, TYPE_ASSERTION_FAILURE}; use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder}; use crate::types::diagnostic::STATIC_ASSERT_ERROR; use crate::Db; @@ -44,6 +44,10 @@ pub(super) enum CallOutcome<'db> { binding: CallBinding<'db>, error_kind: StaticAssertionErrorKind<'db>, }, + AssertType { + binding: CallBinding<'db>, + asserted_ty: Type<'db>, + }, } impl<'db> CallOutcome<'db> { @@ -76,6 +80,14 @@ impl<'db> CallOutcome<'db> { } } + /// Create a new `CallOutcome::AssertType` with given revealed and return types. + pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> { + CallOutcome::AssertType { + binding, + asserted_ty, + } + } + /// Get the return type of the call, or `None` if not callable. pub(super) fn return_ty(&self, db: &'db dyn Db) -> Option> { match self { @@ -103,6 +115,10 @@ impl<'db> CallOutcome<'db> { .map(UnionBuilder::build), Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_ty(db), Self::StaticAssertionError { .. } => Some(Type::none(db)), + Self::AssertType { + binding, + asserted_ty: _, + } => Some(binding.return_ty()), } } @@ -309,6 +325,28 @@ impl<'db> CallOutcome<'db> { Ok(Type::unknown()) } + CallOutcome::AssertType { + binding, + asserted_ty, + } => { + let [actual_ty, _asserted] = binding.parameter_tys() else { + return Ok(binding.return_ty()); + }; + + if !actual_ty.is_gradual_equivalent_to(context.db(), *asserted_ty) { + context.report_lint( + &TYPE_ASSERTION_FAILURE, + node, + format_args!( + "Actual type `{}` is not the same as asserted type `{}`", + actual_ty.display(context.db()), + asserted_ty.display(context.db()), + ), + ); + } + + Ok(binding.return_ty()) + } } } } diff --git a/crates/red_knot_python_semantic/src/types/diagnostic.rs b/crates/red_knot_python_semantic/src/types/diagnostic.rs index 52faea6ec8cd31..80e80950252b23 100644 --- a/crates/red_knot_python_semantic/src/types/diagnostic.rs +++ b/crates/red_knot_python_semantic/src/types/diagnostic.rs @@ -49,6 +49,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) { registry.register_lint(&POSSIBLY_UNBOUND_IMPORT); registry.register_lint(&POSSIBLY_UNRESOLVED_REFERENCE); registry.register_lint(&SUBCLASS_OF_FINAL_CLASS); + registry.register_lint(&TYPE_ASSERTION_FAILURE); registry.register_lint(&TOO_MANY_POSITIONAL_ARGUMENTS); registry.register_lint(&UNDEFINED_REVEAL); registry.register_lint(&UNKNOWN_ARGUMENT); @@ -575,6 +576,28 @@ declare_lint! { } } +declare_lint! { + /// ## What it does + /// Checks for `assert_type()` calls where the actual type + /// is not the same as the asserted type. + /// + /// ## Why is this bad? + /// `assert_type()` allows confirming the inferred type of a certain value. + /// + /// ## Example + /// + /// ```python + /// def _(x: int): + /// assert_type(x, int) # fine + /// assert_type(x, str) # error: Actual type does not match asserted type + /// ``` + pub(crate) static TYPE_ASSERTION_FAILURE = { + summary: "detects failed type assertions", + status: LintStatus::preview("1.0.0"), + default_level: Level::Error, + } +} + declare_lint! { /// ## What it does /// Checks for calls that pass more positional arguments than the callable can accept. diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index fd263a883619d1..a3830072e78c3d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -932,7 +932,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_type_parameters(type_params); if let Some(arguments) = class.arguments.as_deref() { - self.infer_arguments(arguments, false); + self.infer_arguments(arguments, 0b0); } } @@ -2539,17 +2539,20 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_arguments<'a>( &mut self, arguments: &'a ast::Arguments, - infer_as_type_expressions: bool, + infer_as_type_expressions: u32, ) -> CallArguments<'a, 'db> { - let infer_argument_type = if infer_as_type_expressions { - Self::infer_type_expression - } else { - Self::infer_expression - }; - arguments .arguments_source_order() - .map(|arg_or_keyword| { + .enumerate() + .map(|(index, arg_or_keyword)| { + let infer_argument_type = if index < u32::BITS as usize + && infer_as_type_expressions & (1 << index) != 0 + { + Self::infer_type_expression + } else { + Self::infer_expression + }; + match arg_or_keyword { ast::ArgOrKeyword::Arg(arg) => match arg { ast::Expr::Starred(ast::ExprStarred { @@ -3095,7 +3098,8 @@ impl<'db> TypeInferenceBuilder<'db> { let infer_arguments_as_type_expressions = function_type .into_function_literal() .and_then(|f| f.known(self.db())) - .is_some_and(KnownFunction::takes_type_expression_arguments); + .map(KnownFunction::takes_type_expression_arguments) + .unwrap_or(0b0); let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions); function_type