Skip to content

Commit

Permalink
Add support for custom error messages in static_assert
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp committed Jan 8, 2025
1 parent 2d85dc1 commit 599dfe8
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ from knot_extensions import static_assert
# error: [static-assert-error]
static_assert()

# error: [too-many-positional-arguments] "Too many positional arguments to function `static_assert`: expected 1, got 3"
# error: [too-many-positional-arguments] "Too many positional arguments to function `static_assert`: expected 2, got 3"
static_assert(True, 2, 3)
```

Expand Down
22 changes: 22 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/type_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,28 @@ class InvalidBoolDunder:
static_assert(InvalidBoolDunder())
```

### Custom error messages

Alternatively, users can provide custom error messages:

```py
from knot_extensions import static_assert

# error: "Static assertion error: I really want this to be true"
static_assert(1 + 1 == 3, "I really want this to be true")

error_message = "A custom message "
error_message += "constructed from multiple string literals"
# error: "Static assertion error: A custom message constructed from multiple string literals"
static_assert(False, error_message)

# There are limitations to what we can still infer as a string literal. In those cases,
# we simply fall back to the default message.
shouted_message = "A custom message".upper()
# error: "Static assertion error: argument evaluates to `False`"
static_assert(False, shouted_message)
```

## Type predicates

The `knot_extensions` module also provides predicates to test various properties of types. These are
Expand Down
12 changes: 11 additions & 1 deletion crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,13 @@ impl<'db> Type<'db> {
}
}

pub fn into_string_literal(self) -> Option<StringLiteralType<'db>> {
match self {
Type::StringLiteral(string_literal) => Some(string_literal),
_ => None,
}
}

#[track_caller]
pub fn expect_int_literal(self) -> i64 {
self.into_int_literal()
Expand Down Expand Up @@ -1779,16 +1786,19 @@ impl<'db> Type<'db> {
CallOutcome::revealed(binding, revealed_ty)
}
Some(KnownFunction::StaticAssert) => {
if let Some(parameter_ty) = binding.one_parameter_ty() {
if let Some((parameter_ty, message)) = binding.two_parameter_tys() {
let truthiness = parameter_ty.bool(db);

if truthiness.is_always_true() {
CallOutcome::callable(binding)
} else {
let message = message.into_string_literal().map(|s| &**s.value(db));

CallOutcome::StaticAssertionError {
binding,
parameter_ty,
truthiness,
message,
}
}
} else {
Expand Down
18 changes: 18 additions & 0 deletions crates/red_knot_python_semantic/src/types/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(super) enum CallOutcome<'db> {
binding: CallBinding<'db>,
parameter_ty: Type<'db>,
truthiness: Truthiness,
message: Option<&'db str>,
},
}

Expand Down Expand Up @@ -258,10 +259,24 @@ impl<'db> CallOutcome<'db> {
}),
}
}
CallOutcome::StaticAssertionError {
binding: _,
parameter_ty: _,
truthiness: _,
message: Some(message),
} => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!("Static assertion error: {message}"),
);
Ok(Type::Unknown)
}
CallOutcome::StaticAssertionError {
binding,
parameter_ty: Type::BooleanLiteral(false),
truthiness: _,
message: _,
} => {
binding.report_diagnostics(context, node);
context.report_lint(
Expand All @@ -276,6 +291,7 @@ impl<'db> CallOutcome<'db> {
binding,
parameter_ty,
truthiness,
message: _,
} if truthiness.is_always_false() => {
binding.report_diagnostics(context, node);
context.report_lint(
Expand All @@ -293,6 +309,7 @@ impl<'db> CallOutcome<'db> {
binding,
parameter_ty,
truthiness,
message: _,
} if truthiness.is_ambiguous() => {
binding.report_diagnostics(context, node);
context.report_lint(
Expand All @@ -310,6 +327,7 @@ impl<'db> CallOutcome<'db> {
binding,
parameter_ty,
truthiness: _,
message: _,
} => {
binding.report_diagnostics(context, node);
context.report_lint(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import _SpecialForm, Any
from typing import _SpecialForm, Any, LiteralString

# Special operations
def static_assert(condition: object) -> None: ...
def static_assert(condition: object, msg: LiteralString | None = None) -> None: ...

# Types
Unknown = object()
Expand Down

0 comments on commit 599dfe8

Please sign in to comment.