Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adjust how aliases are formatted #4750

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions prqlc/prqlc-parser/src/parser/pr/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ impl ExprKind {
doc_comment: None,
}
}

/// Reports whether this expression kind, when written out, is made up of
/// several top-level items separated by spaces.
///
/// `Binary` and `Func` kinds always count as multiple items, and a
/// `FuncCall` does so when it carries at least one positional argument
/// (for example `sum foo`). Every other kind is treated as a single,
/// self-contained item (for example the array `[foo, bar]`).
pub fn is_multiple_items(&self) -> bool {
    if let ExprKind::FuncCall(call) = self {
        // Only the positional `args` are inspected here; a call without any
        // positional argument is reported as a single item.
        return !call.args.is_empty();
    }
    matches!(self, ExprKind::Binary(_) | ExprKind::Func(_))
}
}

#[derive(Debug, EnumAsInner, PartialEq, Clone, Serialize, Deserialize, JsonSchema)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: prqlc/prqlc-parser/src/test.rs
expression: "parse_single(r#\"\nfrom employees\nfilter country == \"USA\" # Each line transforms the previous result.\nderive { # This adds columns / variables.\n gross_salary = salary + payroll_tax,\n gross_cost = gross_salary + benefits_cost # Variables can use other variables.\n}\nfilter gross_cost > 0\ngroup {title, country} ( # For each group use a nested pipeline\n aggregate { # Aggregate each group to a single row\n average salary,\n average gross_salary,\n sum salary,\n sum gross_salary,\n average gross_cost,\n sum_gross_cost = sum gross_cost,\n ct = count salary,\n }\n)\nsort sum_gross_cost\nfilter ct > 200\ntake 20\n \"#).unwrap()"
expression: "parse_source(r#\"\nfrom employees\nfilter country == \"USA\" # Each line transforms the previous result.\nderive { # This adds columns / variables.\n gross_salary = salary + payroll_tax,\n gross_cost = gross_salary + benefits_cost # Variables can use other variables.\n}\nfilter gross_cost > 0\ngroup {title, country} ( # For each group use a nested pipeline\n aggregate { # Aggregate each group to a single row\n average salary,\n average gross_salary,\n sum salary,\n sum gross_salary,\n average gross_cost,\n sum_gross_cost = sum gross_cost,\n ct = count salary,\n }\n)\nsort sum_gross_cost\nfilter ct > 200\ntake 20\n \"#).unwrap()"
---
- VarDef:
kind: Main
Expand Down
70 changes: 56 additions & 14 deletions prqlc/prqlc/src/codegen/ast.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashSet;
use std::sync::OnceLock;

use prqlc_parser::parser::pr::FuncCall;
use regex::Regex;

use super::{WriteOpt, WriteSource};
Expand All @@ -26,28 +27,64 @@ impl WriteSource for pr::Expr {
fn write(&self, mut opt: WriteOpt) -> Option<String> {
let mut r = String::new();

// If there's an alias, then the expr can't be unbound on its left (we
// set this before evaluating parentheses)
if self.alias.is_some() {
opt.unbound_expr = false;
}

// We need to know before we print the alias whether our `.kind` needs
// parentheses, because a parenthesized value is a single_expr and
// doesn't need a space around the alias' `=`.
let kind_needs_parens = needs_parenthesis(self, &opt.clone());
let single_expr = kind_needs_parens || !self.kind.is_multiple_items();

if let Some(alias) = &self.alias {
r += opt.consume(alias)?;
r += opt.consume(" = ")?;
opt.unbound_expr = false;
r += opt.consume(if single_expr { "=" } else { " = " })?;
}

if !needs_parenthesis(self, &opt) {
r += &self.kind.write(opt.clone())?;
if !kind_needs_parens {
r += &self.kind.write(opt.clone())?
} else {
let value = self.kind.write_between("(", ")", opt.clone());

if let Some(value) = value {
r += &value;
r += &value
} else {
r += &break_line_within_parenthesis(&self.kind, opt)?;
r += &break_line_within_parenthesis(&self.kind, opt.clone())?
}
};

opt.single_term = match &self.kind {
// If this item is within an array / tuple, then the nested item is a
// single term, for example `derive {x = sum listens}`.
pr::ExprKind::Array(_) | pr::ExprKind::Tuple(_) => true,
// If it's a single arg in a function call, that's a single term,
// like `derive x = foo` (but not like `derive x=(sum foo)`)
pr::ExprKind::FuncCall(FuncCall {
args, named_args, ..
}) => dbg!(args.len() + named_args.len() == 1),
_ => false,
};

dbg!((&r, &self, opt));

Some(r)
}
}

fn needs_parenthesis(this: &pr::Expr, opt: &WriteOpt) -> bool {
// If we have an alias, we use parentheses if we contain multiple items, so
// we get `a=(b + c)` instead of `a=b + c`.
//
// The exception is when we're within a single term: then this doesn't apply (we
// don't return true here) — so then we get `derive {a = b + c}` rather than
// `derive {a=(b + c)}`.
if this.alias.is_some() && this.kind.is_multiple_items() && !opt.single_term {
return true;
}

if opt.unbound_expr && can_bind_left(&this.kind) {
return true;
}
Expand All @@ -68,12 +105,12 @@ fn needs_parenthesis(this: &pr::Expr, opt: &WriteOpt) -> bool {
// parent has equal binding strength, which means that now associativity of this expr counts
// for example:
// this=(a + b), parent=(a + b) + c
// asoc of + is left
// assoc of + is left
// this is the left operand of parent
// => assoc_matches=true => we don't need parenthesis

// this=(a + b), parent=c + (a + b)
// asoc of + is left
// assoc of + is left
// this is the right operand of parent
// => assoc_matches=false => we need parenthesis
let assoc_matches = match opt.binary_position {
Expand Down Expand Up @@ -474,7 +511,7 @@ impl WriteSource for pr::Stmt {
r += opt.consume(&format!("type {}", type_def.name))?;

if let Some(ty) = &type_def.value {
r += opt.consume(" = ")?;
r += opt.consume("=")?;
r += &ty.kind.write(opt)?;
}
r += "\n";
Expand All @@ -493,7 +530,7 @@ impl WriteSource for pr::Stmt {
r += "import ";
if let Some(alias) = &import_def.alias {
r += &write_ident_part(alias);
r += " = ";
r += "=";
}
r += &import_def.name.write(opt)?;
r += "\n";
Expand Down Expand Up @@ -614,7 +651,7 @@ mod test {
fn test_unary() {
assert_is_formatted(r#"sort {-duration}"#);

assert_is_formatted(r#"select a = -b"#);
assert_is_formatted(r#"select a=-b"#);
assert_is_formatted(r#"join `project-bar.dataset.table` (==col_bax)"#);
}

Expand All @@ -632,15 +669,20 @@ mod test {
#[test]
fn test_func() {
assert_is_formatted(r#"let a = func x y:false -> x and y"#);
// See notes in [WriteOpt::single_term]
assert_is_formatted(r#"derive rounded=(round 2 gross_cost)"#);
assert_is_formatted(r#"aggregate {sum_gross_cost = sum gross_cost}"#);
// Not sure about these? TODO: decide
// assert_is_formatted(r#"derive foo = 2 + 3"#);
assert_is_formatted(r#"derive foo=(2 + 3)"#);
}

#[test]
fn test_simple() {
assert_is_formatted(
r#"
aggregate average_country_salary = (
average salary
)"#,
aggregate average_country_salary=(average salary)
"#,
);
}

Expand Down
20 changes: 17 additions & 3 deletions prqlc/prqlc/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub trait WriteSource {
r += opt.consume(&prefix.to_string())?;
opt.context_strength = 0;
opt.unbound_expr = false;
opt.single_term = true;

let source = self.write(opt.clone())?;
r += opt.consume(&source)?;
Expand All @@ -47,7 +48,7 @@ impl<T: WriteSource> WriteSource for &T {
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct WriteOpt {
/// String to emit as one indentation level
pub tab: &'static str,
Expand All @@ -74,11 +75,23 @@ pub struct WriteOpt {
/// be mistakenly bound into a binary op by appending an unary op.
///
/// For example:
/// `join foo` has an unbound expr, since `join foo ==bar` produced a binary op.
/// `join foo` has an unbound expr, since `join foo ==bar` produced a binary
/// op
/// ...so we need to parenthesize the `==bar`.
pub unbound_expr: bool,

/// True iff the expression is surrounded by spaces and so we prefer to
/// format with parentheses. For example:
/// - the function call in `derive rounded=(round 2 gross_cost)` is
/// formatted without spaces around the `=` because it's a single term.
/// - the binary expression in `derive foo = 2 + 3` is formatted with spaces
/// around the `=` because it's not a single term
/// - the function call in `aggregate {sum_gross_cost = sum gross_cost}` is
/// formatted with spaces around the `=` because it's not a single term.
pub single_term: bool,
}

#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Position {
Unspecified,
Left,
Expand All @@ -96,6 +109,7 @@ impl Default for WriteOpt {
context_strength: 0,
binary_position: Position::Unspecified,
unbound_expr: false,
single_term: false,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions prqlc/prqlc/src/codegen/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl WriteSource for pr::TyTupleField {

if let Some(name) = name {
r += name;
r += " = ";
r += "=";
}
if let Some(expr) = expr {
r += &expr.write(opt)?;
Expand All @@ -121,7 +121,7 @@ impl WriteSource for UnionVariant<'_> {
let mut r = String::new();
if let Some(name) = &self.0 {
r += name;
r += " = ";
r += "=";
}
opt.consume_width(r.len() as u16);
r += &self.1.write(opt)?;
Expand Down
8 changes: 4 additions & 4 deletions prqlc/prqlc/src/semantic/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ mod test {
assert_snapshot!(eval(r"
{{a_a = 4, a_b = false}, b = 2.1 + 3.6, c = [false, true, false]}
").unwrap(),
@"{{a_a = 4, a_b = false}, b = 5.7, c = [false, true, false]}"
@"{{a_a=4, a_b=false}, b=5.7, c=[false, true, false]}"
);
}

Expand All @@ -507,7 +507,7 @@ mod test {
std.derive {d = 42}
std.filter c
").unwrap(),
@"[{c = true, 7, d = 42}, {c = true, 14, d = 42}]"
@"[{c=true, 7, d=42}, {c=true, 14, d=42}]"
);
}

Expand All @@ -521,7 +521,7 @@ mod test {
]
std.window {d = std.sum b}
").unwrap(),
@"[{d = 4}, {d = 9}, {d = 17}]"
@"[{d=4}, {d=9}, {d=17}]"
);
}

Expand All @@ -535,7 +535,7 @@ mod test {
]
std.columnar {g = std.lag b}
").unwrap(),
@"[{g = null}, {g = 4}, {g = 5}]"
@"[{g=null}, {g=4}, {g=5}]"
);
}
}
4 changes: 2 additions & 2 deletions prqlc/prqlc/tests/integration/project/Project.prql
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
let favorite_artists = [
{artist_id = 120, last_listen = @2023-05-18},
{artist_id = 7, last_listen = @2023-05-16},
{artist_id=120, last_listen=@2023-05-18},
{artist_id=7, last_listen=@2023-05-16},
]

favorite_artists
Expand Down
16 changes: 8 additions & 8 deletions prqlc/prqlc/tests/integration/resolving.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn resolve_basic_01() {
from x
select {a, b}
"#).unwrap(), @r###"
let main <[{a = ?, b = ?}]> = `(Select ...)`
let main <[{a=?, b=?}]> = `(Select ...)`
"###)
}

Expand All @@ -53,7 +53,7 @@ fn resolve_types_01() {
assert_snapshot!(resolve(r#"
type A = int || int
"#).unwrap(), @r###"
type A = int
type A=int
"###)
}

Expand All @@ -62,7 +62,7 @@ fn resolve_types_02() {
assert_snapshot!(resolve(r#"
type A = int || {}
"#).unwrap(), @r###"
type A = int || {}
type A=int || {}
"###)
}

Expand All @@ -71,7 +71,7 @@ fn resolve_types_03() {
assert_snapshot!(resolve(r#"
type A = {a = int, bool} || {b = text, float}
"#).unwrap(), @r###"
type A = {a = int, bool, b = text, float}
type A={a=int, bool, b=text, float}
"###)
}

Expand All @@ -87,9 +87,9 @@ fn resolve_types_04() {
"#,
)
.unwrap(), @r###"
type Status = (
Unpaid = float ||
{reason = text, cancelled_at = timestamp} ||
type Status=(
Unpaid=float ||
{reason=text, cancelled_at=timestamp} ||
)
"###);
}
Expand All @@ -103,7 +103,7 @@ fn resolve_types_05() {
"#,
)
.unwrap(), @r###"
type A = null
type A=null
"###);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ input_file: prqlc/prqlc/tests/integration/queries/aggregation.prql
---
from tracks
filter genre_id == 100
derive empty_name = name == ""
derive empty_name=(name == "")
aggregate {
sum track_id,
concat_array name,
all empty_name,
any empty_name,
}

Loading
Loading