Skip to content

Commit

Permalink
Check temporary table with qualified name
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenn committed Feb 25, 2024
1 parent 8cb2e10 commit caa629b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion checks.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
## TODO

### `CREATE TABLE`
- [ ] qualified (different of `temp`) temporary table
- [X] qualified (different of `temp`) temporary table
```sql
sqlite> ATTACH DATABASE ':memory:' AS mem;
sqlite> CREATE TEMPORARY TABLE mem.x AS SELECT 1;
Expand Down
64 changes: 33 additions & 31 deletions src/lexer/sql/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,43 @@ use crate::parser::{
};

#[test]
fn count_placeholders() -> Result<(), Error> {
let mut parser = Parser::new(b"SELECT ? WHERE 1 = ?");
let ast = parser.next()?.unwrap();
fn count_placeholders() {
let ast = parse_cmd(b"SELECT ? WHERE 1 = ?");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 2);
Ok(())
}

#[test]
fn count_numbered_placeholders() -> Result<(), Error> {
let mut parser = Parser::new(b"SELECT ?1 WHERE 1 = ?2 AND 0 = ?1");
let ast = parser.next()?.unwrap();
fn count_numbered_placeholders() {
let ast = parse_cmd(b"SELECT ?1 WHERE 1 = ?2 AND 0 = ?1");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 2);
Ok(())
}

#[test]
fn count_unused_placeholders() -> Result<(), Error> {
let mut parser = Parser::new(b"SELECT ?1 WHERE 1 = ?3");
let ast = parser.next()?.unwrap();
fn count_unused_placeholders() {
let ast = parse_cmd(b"SELECT ?1 WHERE 1 = ?3");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 3);
Ok(())
}

#[test]
fn count_named_placeholders() -> Result<(), Error> {
let mut parser = Parser::new(b"SELECT :x, :y WHERE 1 = :y");
let ast = parser.next()?.unwrap();
fn count_named_placeholders() {
let ast = parse_cmd(b"SELECT :x, :y WHERE 1 = :y");
let mut info = ParameterInfo::default();
ast.to_tokens(&mut info).unwrap();
assert_eq!(info.count, 2);
assert_eq!(info.names.len(), 2);
assert!(info.names.contains(":x"));
assert!(info.names.contains(":y"));
Ok(())
}

#[test]
fn duplicate_column() {
let mut parser = Parser::new(b"CREATE TABLE t (x TEXT, x TEXT)");
let r = parser.next();
let r = parse(b"CREATE TABLE t (x TEXT, x TEXT)");
let Error::ParserError(ParserError::Custom(msg), _) = r.unwrap_err() else {
panic!("unexpected error type")
};
Expand All @@ -62,8 +53,7 @@ fn duplicate_column() {

#[test]
fn create_table_without_column() {
let mut parser = Parser::new(b"CREATE TABLE t ()");
let r = parser.next();
let r = parse(b"CREATE TABLE t ()");
let Error::ParserError(
ParserError::SyntaxError {
token_type: "RP",
Expand All @@ -82,7 +72,7 @@ fn vtab_args() -> Result<(), Error> {
subject VARCHAR(256) NOT NULL,
body TEXT CHECK(length(body)<10240)
);"#;
let mut parser = Parser::new(sql.as_bytes());
let r = parse_cmd(sql.as_bytes());
let Cmd::Stmt(Stmt::CreateVirtualTable {
tbl_name: QualifiedName {
name: Name(tbl_name),
Expand All @@ -91,7 +81,7 @@ fn vtab_args() -> Result<(), Error> {
module_name: Name(module_name),
args: Some(args),
..
}) = parser.next()?.unwrap()
}) = r
else {
panic!("unexpected AST")
};
Expand All @@ -107,8 +97,8 @@ fn vtab_args() -> Result<(), Error> {
fn only_semicolons_no_statements() {
let sqls = ["", ";", ";;;"];
for sql in sqls.iter() {
let mut parser = Parser::new(sql.as_bytes());
assert_eq!(parser.next().unwrap(), None);
let r = parse(sql.as_bytes());
assert_eq!(r.unwrap(), None);
}
}

Expand Down Expand Up @@ -193,6 +183,15 @@ fn create_table_without_rowid_missing_pk() {
);
}

#[test]
fn create_temporary_table_with_qualified_name() {
expect_parser_err(
b"CREATE TEMPORARY TABLE mem.x AS SELECT 1",
"temporary table name must be unqualified",
);
parse_cmd(b"CREATE TEMPORARY TABLE temp.x AS SELECT 1");
}

#[test]
fn create_strict_table_missing_datatype() {
expect_parser_err(b"CREATE TABLE t (c1) STRICT", "missing datatype for t.c1");
Expand All @@ -208,16 +207,13 @@ fn create_strict_table_unknown_datatype() {

#[test]
fn create_strict_table_generated_column() {
let mut parser = Parser::new(
parse_cmd(
b"CREATE TABLE IF NOT EXISTS transactions (
debit REAL,
credit REAL,
amount REAL GENERATED ALWAYS AS (ifnull(credit, 0.0) -ifnull(debit, 0.0))
) STRICT;
",
) STRICT;",
);
let r = parser.next();
r.unwrap();
}

#[test]
Expand Down Expand Up @@ -289,11 +285,17 @@ fn natural_join_on() {
}

fn expect_parser_err(input: &[u8], error_msg: &str) {
let mut parser = Parser::new(input);
let r = parser.next();
let r = parse(input);
if let Error::ParserError(ParserError::Custom(ref msg), _) = r.unwrap_err() {
assert_eq!(msg, error_msg);
} else {
panic!("unexpected error type")
};
}
fn parse_cmd(input: &[u8]) -> Cmd {
parse(input).unwrap().unwrap()
}
fn parse(input: &[u8]) -> Result<Option<Cmd>, Error> {
let mut parser = Parser::new(input);
parser.next()
}
16 changes: 15 additions & 1 deletion src/parser/ast/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,21 @@ impl Stmt {
}
Ok(())
}
Stmt::CreateTable { tbl_name, body, .. } => body.check(tbl_name),
Stmt::CreateTable {
temporary,
tbl_name,
body,
..
} => {
if *temporary {
if let Some(ref db_name) = tbl_name.db_name {
if !"TEMP".eq_ignore_ascii_case(&db_name.0) {
return Err(custom_err!("temporary table name must be unqualified"));
}
}
}
body.check(tbl_name)
}
Stmt::CreateView {
view_name,
columns: Some(columns),
Expand Down

0 comments on commit caa629b

Please sign in to comment.