Skip to content

Commit

Permalink
Move special forms handling into infer.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp committed Jan 7, 2025
1 parent a14af21 commit 4aaacee
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 148 deletions.
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/resources/mdtest/type_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,9 @@ from knot_extensions import Not, Unknown, TypeOf
# error: "Special form `knot_extensions.Unknown` expected no type parameter"
u: Unknown[str]

# error: "Expected 1 type argument, got 2"
# error: "Special form `knot_extensions.Not` expected exactly one type parameter"
n: Not[int, str]

# error: "Expected 1 type argument, got 3"
# error: "Special form `knot_extensions.TypeOf` expected exactly one type parameter"
t: TypeOf[int, str, bytes]
```
32 changes: 16 additions & 16 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ mod narrow;
mod signatures;
mod slots;
mod string_annotation;
mod type_api;
mod unpacker;

#[cfg(test)]
Expand Down Expand Up @@ -1821,34 +1820,35 @@ impl<'db> Type<'db> {
}

Some(KnownFunction::IsEquivalentTo) => {
let [ty_a, ty_b] = binding.parameter_tys() else {
todo!()
let return_ty = match binding.parameter_tys() {
[ty_a, ty_b] => Type::BooleanLiteral(ty_a.is_equivalent_to(db, *ty_b)),
_ => Type::Unknown,
};
binding
.set_return_ty(Type::BooleanLiteral(ty_a.is_equivalent_to(db, *ty_b)));
binding.set_return_ty(return_ty);
CallOutcome::callable(binding)
}
Some(KnownFunction::IsSubtypeOf) => {
let [ty_a, ty_b] = binding.parameter_tys() else {
todo!()
let return_ty = match binding.parameter_tys() {
[ty_a, ty_b] => Type::BooleanLiteral(ty_a.is_subtype_of(db, *ty_b)),
_ => Type::Unknown,
};
binding.set_return_ty(Type::BooleanLiteral(ty_a.is_subtype_of(db, *ty_b)));
binding.set_return_ty(return_ty);
CallOutcome::callable(binding)
}
Some(KnownFunction::IsAssignableTo) => {
let [ty_a, ty_b] = binding.parameter_tys() else {
todo!()
let return_ty = match binding.parameter_tys() {
[ty_a, ty_b] => Type::BooleanLiteral(ty_a.is_assignable_to(db, *ty_b)),
_ => Type::Unknown,
};
binding
.set_return_ty(Type::BooleanLiteral(ty_a.is_assignable_to(db, *ty_b)));
binding.set_return_ty(return_ty);
CallOutcome::callable(binding)
}
Some(KnownFunction::IsDisjointFrom) => {
let [ty_a, ty_b] = binding.parameter_tys() else {
todo!()
let return_ty = match binding.parameter_tys() {
[ty_a, ty_b] => Type::BooleanLiteral(ty_a.is_disjoint_from(db, *ty_b)),
_ => Type::Unknown,
};
binding
.set_return_ty(Type::BooleanLiteral(ty_a.is_disjoint_from(db, *ty_b)));
binding.set_return_ty(return_ty);
CallOutcome::callable(binding)
}
Some(KnownFunction::IsFullyStatic) => {
Expand Down
18 changes: 0 additions & 18 deletions crates/red_knot_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&UNRESOLVED_REFERENCE);
registry.register_lint(&UNSUPPORTED_OPERATOR);
registry.register_lint(&ZERO_STEPSIZE_IN_SLICE);
registry.register_lint(&TYPE_API_WRONG_ARITY);
registry.register_lint(&STATIC_ASSERT_ERROR);

// String annotations
Expand Down Expand Up @@ -680,23 +679,6 @@ declare_lint! {
}
}

declare_lint! {
/// ## What it does
/// Checks for `knot_extensions` type API calls with the wrong number of arguments.
///
/// ## Examples
/// ```python
/// from knot_extensions import is_equivalent_to
///
/// is_equivalent_to(int, str, bool) # error: wrong number of arguments
/// ```
pub(crate) static TYPE_API_WRONG_ARITY = {
summary: "wrong number of arguments",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}

declare_lint! {
/// ## What it does
/// Makes sure that the argument of `static_assert` has a statically-known truthiness.
Expand Down
90 changes: 46 additions & 44 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,9 @@ use crate::types::diagnostic::{
CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE, INCONSISTENT_MRO, INVALID_BASE,
INVALID_CONTEXT_MANAGER, INVALID_DECLARATION, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM,
INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_ATTRIBUTE, POSSIBLY_UNBOUND_IMPORT,
TYPE_API_WRONG_ARITY, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT,
UNSUPPORTED_OPERATOR,
UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR,
};
use crate::types::mro::MroErrorKind;
use crate::types::type_api::{self, TypeApiArgumentsError, TypeApiSpecialForm};
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, todo_type,
Expand Down Expand Up @@ -4290,40 +4288,6 @@ impl<'db> TypeInferenceBuilder<'db> {
Ok(builder.build())
}

fn infer_type_api_special_form(
&mut self,
special_form: TypeApiSpecialForm,
arguments: &ast::Expr,
) -> Type<'db> {
let db = self.db();

let argument_types = match (special_form, arguments) {
(_, ast::Expr::Tuple(tuple)) => Either::Left(
tuple
.iter()
.map(|element| self.infer_type_expression(element)),
),
(TypeApiSpecialForm::TypeOf, expr) => {
Either::Right(std::iter::once(self.infer_expression(expr)))
}
(_, expr) => Either::Right(std::iter::once(self.infer_type_expression(expr))),
};

type_api::resolve_special_form(db, special_form, argument_types).unwrap_or_else(
|TypeApiArgumentsError { expected, actual }| {
self.context.report_lint(
&TYPE_API_WRONG_ARITY,
arguments.into(),
format_args!(
"Expected {expected} type argument{}, got {actual}",
if expected == 1 { "" } else { "s" },
),
);
Type::Unknown
},
)
}

fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
let ast::ExprSubscript {
range: _,
Expand Down Expand Up @@ -5194,15 +5158,53 @@ impl<'db> TypeInferenceBuilder<'db> {
}

// Type API special forms
KnownInstanceType::Not => {
self.infer_type_api_special_form(TypeApiSpecialForm::Not, arguments_slice)
}
KnownInstanceType::Not => match arguments_slice {
ast::Expr::Tuple(_) => {
self.context.report_lint(
&INVALID_TYPE_FORM,
subscript.into(),
format_args!(
"Special form `{}` expected exactly one type parameter",
known_instance.repr(self.db())
),
);
Type::Unknown
}
_ => {
let argument_type = self.infer_type_expression(arguments_slice);
argument_type.negate(self.db())
}
},
KnownInstanceType::Intersection => {
self.infer_type_api_special_form(TypeApiSpecialForm::Intersection, arguments_slice)
}
KnownInstanceType::TypeOf => {
self.infer_type_api_special_form(TypeApiSpecialForm::TypeOf, arguments_slice)
let elements = match arguments_slice {
ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()),
element => Either::Right(std::iter::once(element)),
};

elements
.fold(IntersectionBuilder::new(self.db()), |builder, element| {
builder.add_positive(self.infer_type_expression(element))
})
.build()
}
KnownInstanceType::TypeOf => match arguments_slice {
ast::Expr::Tuple(_) => {
self.context.report_lint(
&INVALID_TYPE_FORM,
subscript.into(),
format_args!(
"Special form `{}` expected exactly one type parameter",
known_instance.repr(self.db())
),
);
Type::Unknown
}
_ => {
// NB: This calls `infer_expression` instead of `infer_type_expression`.
let argument_type = self.infer_expression(arguments_slice);
argument_type
}
},

// TODO: Generics
KnownInstanceType::ChainMap => {
Expand Down
68 changes: 0 additions & 68 deletions crates/red_knot_python_semantic/src/types/type_api.rs

This file was deleted.

0 comments on commit 4aaacee

Please sign in to comment.