From b5d8b8f237ebb5a36dc3472d4f3e33e5880220d1 Mon Sep 17 00:00:00 2001 From: InSyncWithFoo Date: Thu, 9 Jan 2025 22:13:39 +0000 Subject: [PATCH] Per review --- .../mdtest/directives/assert_type.md | 31 +++++++-------- crates/red_knot_python_semantic/src/types.rs | 39 ++++++++++++------- .../src/types/infer.rs | 25 +++++++----- 3 files changed, 55 insertions(+), 40 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md b/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md index 54d1f98fe816c0..9c62f025f81b56 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md +++ b/crates/red_knot_python_semantic/resources/mdtest/directives/assert_type.md @@ -46,8 +46,7 @@ def _(a: type[int], b: type[Any]): # The expression constructing the type is not taken into account def _(a: type[int]): - # TODO: Infer the second argument as a type expression - assert_type(a, Type[int]) # error: [type-assertion-failure] + assert_type(a, Type[int]) # fine ``` ## Gradual types @@ -69,14 +68,12 @@ def _(a: Unknown, b: Any): def _(a: type[Unknown], b: type[Any]): # TODO: Should be `type[Unknown]` reveal_type(a) # revealed: @Todo(unsupported type[X] special form) - reveal_type(b) # revealed: type[Any] + # TODO: Should be fine + assert_type(a, type[Any]) # error: [type-assertion-failure] - # TODO: Infer the second argument as a type expression - # Should be fine - assert_type(a, type[Unknown]) # error: [type-assertion-failure] - # TODO: Infer the second argument as a type expression - # Should be fine - assert_type(b, type[Any]) # error: [type-assertion-failure] + reveal_type(b) # revealed: type[Any] + # TODO: Should be fine + assert_type(b, type[Unknown]) # error: [type-assertion-failure] ``` ## Tuples @@ -86,19 +83,21 @@ Tuple types with the same elements are the same. ```py from typing_extensions import assert_type +from knot_extensions import Unknown + def _(a: tuple[int, str, bytes]): - # TODO: Infer the second argument as a type expression - # Should be fine - assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure] + assert_type(a, tuple[int, str, bytes]) # fine assert_type(a, tuple[int, str]) # error: [type-assertion-failure] assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure] assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure] -def _(a: tuple[Any, ...]): - # TODO: Infer the second argument as a type expression - # Should be fine - assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure] +def _(a: tuple[Any, ...], b: tuple[Unknown, ...]): + assert_type(a, tuple[Any, ...]) # fine + assert_type(a, tuple[Unknown, ...]) # fine + + assert_type(b, tuple[Any, ...]) # fine + assert_type(b, tuple[Unknown, ...]) # fine ``` ## Unions diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 025b3e003b810f..2925f89a2facb5 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -2005,6 +2005,14 @@ impl<'db> Type<'db> { CallOutcome::callable(binding) } + Some(KnownFunction::AssertType) => { + let Some((_, asserted_ty)) = binding.two_parameter_tys() else { + return CallOutcome::callable(binding); + }; + + CallOutcome::asserted(binding, asserted_ty) + } + _ => CallOutcome::callable(binding), } } @@ -3401,17 +3409,23 @@ impl KnownFunction { /// Whether or not a particular function takes type expression as arguments, i.e. should /// the argument of a call like `f(int)` be interpreted as the type int (true) or as the /// type of the expression `int`, i.e. `Literal[int]` (false). - const fn takes_type_expression_arguments(self) -> bool { - matches!( - self, - KnownFunction::IsEquivalentTo - | KnownFunction::IsSubtypeOf - | KnownFunction::IsAssignableTo - | KnownFunction::IsDisjointFrom - | KnownFunction::IsFullyStatic - | KnownFunction::IsSingleton - | KnownFunction::IsSingleValued - ) + const fn takes_type_expression_arguments(self) -> u32 { + const ALL_VALUES: u32 = 0b0; + const SINGLE_TYPE: u32 = 0b1; + const TYPE_TYPE: u32 = 0b11; + const VALUE_TYPE: u32 = 0b10; + + match self { + KnownFunction::IsEquivalentTo => TYPE_TYPE, + KnownFunction::IsSubtypeOf => TYPE_TYPE, + KnownFunction::IsAssignableTo => TYPE_TYPE, + KnownFunction::IsDisjointFrom => TYPE_TYPE, + KnownFunction::IsFullyStatic => SINGLE_TYPE, + KnownFunction::IsSingleton => SINGLE_TYPE, + KnownFunction::IsSingleValued => SINGLE_TYPE, + KnownFunction::AssertType => VALUE_TYPE, + _ => ALL_VALUES, + } } } @@ -4712,7 +4726,6 @@ pub(crate) mod tests { #[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))] #[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))] #[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))] - #[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)] #[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfBuiltinClass("object"))] // TODO: Compare unions/intersections with different orders // #[test_case( @@ -4749,9 +4762,7 @@ pub(crate) mod tests { } #[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfAny)] - #[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfUnknown)] #[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfAny)] - #[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfUnknown)] #[test_case( Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")]) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 915eb79e30febb..f818e4804cff21 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -932,7 +932,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_type_parameters(type_params); if let Some(arguments) = class.arguments.as_deref() { - self.infer_arguments(arguments, false); + self.infer_arguments(arguments, 0b0); } } @@ -2539,17 +2539,21 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_arguments<'a>( &mut self, arguments: &'a ast::Arguments, - infer_as_type_expressions: bool, + infer_as_type_expressions: u32, ) -> CallArguments<'a, 'db> { - let infer_argument_type = if infer_as_type_expressions { - Self::infer_type_expression - } else { - Self::infer_expression - }; - arguments .arguments_source_order() - .map(|arg_or_keyword| { + .enumerate() + .map(|(index, arg_or_keyword)| { + // TODO: Remove this once we have proper overload matching + let infer_argument_type = if index < u32::BITS as usize + && infer_as_type_expressions & (1 << index) != 0 + { + Self::infer_type_expression + } else { + Self::infer_expression + }; + match arg_or_keyword { ast::ArgOrKeyword::Arg(arg) => match arg { ast::Expr::Starred(ast::ExprStarred { @@ -3095,7 +3099,8 @@ impl<'db> TypeInferenceBuilder<'db> { let infer_arguments_as_type_expressions = function_type .into_function_literal() .and_then(|f| f.known(self.db())) - .is_some_and(KnownFunction::takes_type_expression_arguments); + .map(KnownFunction::takes_type_expression_arguments) + .unwrap_or(0b0); let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions); function_type