diff --git a/data/sql_input_2.sql b/data/sql_input_2.sql new file mode 100644 index 0000000..89b50fb --- /dev/null +++ b/data/sql_input_2.sql @@ -0,0 +1,21 @@ +-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html +-- https://www.postgresql.org/docs/16/sql-createtable.html +-- https://www.postgresql.org/docs/16/sql-insert.html +-- https://www.postgresql.org/docs/16/sql-select.html +CREATE TABLE city ( + name VARCHAR, + population INT, + timezone INT +); + +INSERT INTO city (name, population, timezone) +VALUES ('San Francisco', 852469, -8); + +INSERT INTO city (name, population, timezone) +VALUES ('New York', 8405837, -5); + +SELECT + name, + population, + timezone +FROM city; diff --git a/data/sql_input_3.sql b/data/sql_input_3.sql new file mode 100644 index 0000000..df39515 --- /dev/null +++ b/data/sql_input_3.sql @@ -0,0 +1,21 @@ +-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html +-- https://www.postgresql.org/docs/16/sql-createtable.html +-- https://www.postgresql.org/docs/16/sql-insert.html +-- https://www.postgresql.org/docs/16/sql-select.html +CREATE TABLE city ( + name VARCHAR, + population INT, + timezone INT +); + +INSERT INTO city (name, timezone) +VALUES ('San Francisco', -8); + +INSERT INTO city (name, population) +VALUES ('New York', 8405837); + +SELECT + name, + population, + timezone +FROM city; diff --git a/data/sql_output_0.json b/data/sql_output_0.json index eaa5d04..4775e53 100644 --- a/data/sql_output_0.json +++ b/data/sql_output_0.json @@ -1,3 +1 @@ -{ - "table_name": ["city"] -} +[{ "table_name": "city" }] diff --git a/data/sql_output_1.json b/data/sql_output_1.json index 86970c5..223e140 100644 --- a/data/sql_output_1.json +++ b/data/sql_output_1.json @@ -1,3 +1 @@ -{ - "table_name": ["city", "town"] -} +[{ "table_name": "city" }, { "table_name": "town" }] diff --git a/data/sql_output_2.json b/data/sql_output_2.json new file mode 100644 index 0000000..e00405c --- /dev/null +++ b/data/sql_output_2.json @@ -0,0 +1,4 @@ +[ + { "name": "San Francisco", "population": 852469, "timezone": -8 }, + { "name": "New York", "population": 8405837, "timezone": -5 } +] diff --git a/data/sql_output_3.json b/data/sql_output_3.json new file mode 100644 index 0000000..e1591f5 --- /dev/null +++ b/data/sql_output_3.json @@ -0,0 +1,4 @@ +[ + { "name": "San Francisco", "population": null, "timezone": -8 }, + { "name": "New York", "population": 8405837, "timezone": null } +] diff --git a/snippets/python/sql_test.py b/snippets/python/sql_test.py index b9d290e..970bfa0 100644 --- a/snippets/python/sql_test.py +++ b/snippets/python/sql_test.py @@ -1,31 +1,119 @@ +import dataclasses import json +import typing + + +@dataclasses.dataclass(frozen=True) +class SQLState: + state: dict + + def read_table_meta(self, table_name: str) -> dict: + return self.state.get(table_name, {}).get("metadata", {}) + + def read_table_rows(self, table_name: str) -> list[dict]: + return self.state.get(table_name, {}).get("rows", []) + + def read_information_schema(self) -> list[dict]: + return [data["metadata"] for data in self.state.values()] + + def write_table_meta(self, table_name: str, data: dict): + state = self.state + table = state.get(table_name, {}) + metadata = table.get("metadata", {}) + metadata.update(data) + table["metadata"] = metadata + state[table_name] = table + return self.__class__(state) + + def write_table_rows(self, table_name: str, data: dict): + state = self.state + table = state.get(table_name, {}) + rows = table.get("rows", []) + rows.append(data) + table["rows"] = rows + state[table_name] = table + return self.__class__(state) + + +class SQLType: + @staticmethod + def varchar(data) -> str: + data_str = str(data).strip() + if data_str.startswith("'") or data_str.startswith('"'): + data_str = data_str[1:] + if data_str.endswith("'") or data_str.endswith('"'): + data_str = data_str[:-1] + return data_str + + @staticmethod + def int(data) -> int: + return int(data.strip()) + + +sql_type_map = { + "VARCHAR": SQLType.varchar, + "INT": SQLType.int, +} + + +class SQLFunctions: + @staticmethod + def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + table_name = args[2] + # get columns + columns = {} + columns_str = " ".join(args[3:]).replace("(", "").replace(")", "").strip() + if columns_str: + # fmt: off + columns = { + column.strip().split(" ")[0]: column.strip().split(" ")[1] + for column in columns_str.split(",") + } + # fmt: on -class SQL: - data: dict = {} - - def __init__(self) -> None: - self.data = {} - - def information_schema_tables(self) -> list[dict]: - return [data["metadata"] for data in self.data.values()] - - def create_table(self, *args, table_schema="public") -> dict: - table_name = args[2] - if not self.data.get(table_name): - self.data[table_name] = { - "metadata": { + if not state.read_table_meta(table_name): + state = state.write_table_meta( + table_name, + { "table_name": table_name, "table_schema": table_schema, + "colums": columns, }, - } - return {} + ) + return (output, state) + + @staticmethod + def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + table_name = args[2] + + values_index = None + for i, arg in enumerate(args): + if arg == "VALUES": + values_index = i + if values_index is None: + raise ValueError("VALUES not found") + + keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",") + keys = [key.strip() for key in keys] + values = " ".join(args[values_index + 1 :]).replace("(", "").replace(")", "").split(",") + values = [value.strip() for value in values] + key_value_map = dict(zip(keys, values)) - create_table.sql = "CREATE TABLE" + data = {} + if metadata := state.read_table_meta(table_name): + for key, value in key_value_map.items(): + data[key] = sql_type_map[metadata["colums"][key]](value) + state = state.write_table_rows(table_name, data) - def select(self, *args) -> dict: - output = {} + return (output, state) + + @staticmethod + def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: + output: list[dict] = [] from_index = None where_index = None @@ -34,49 +122,59 @@ def select(self, *args) -> dict: from_index = i if arg == "WHERE": where_index = i + if from_index is None: + raise ValueError("FROM not found") # get select keys by getting the slice of args before FROM select_keys = " ".join(args[1:from_index]).split(",") + select_keys = [key.strip() for key in select_keys] # get where keys by getting the slice of args after WHERE from_value = args[from_index + 1] - # consider "information_schema.tables" a special case until - # we figure out why its so different from the others + # `information_schema.tables` is a special case if from_value == "information_schema.tables": - target = self.information_schema_tables() - - # fmt: off - output = { - key: [ - value for data in target - for key, value in data.items() - if key in select_keys - ] - for key in select_keys - } - # fmt: on - - return output - - select.sql = "SELECT" - - sql_map = { - create_table.sql: create_table, - select.sql: select, - } - - def run(self, input_sql: list[str]) -> list[str]: - output = {} - - for line in input_sql: - if not line.startswith("--"): - words = line.split(" ") - for i in reversed(range(len(words))): - key = " ".join(words[:i]) - if func := self.sql_map.get(key): - output = func(self, *words) - break - - return [json.dumps(output)] + data = state.read_information_schema() + else: + data = state.read_table_rows(from_value) + + output = [] + for datum in data: + # fmt: off + output.append({ + key: datum.get(key) + for key in select_keys + }) + # fmt: on + + return (output, state) + + +sql_function_map: dict[str, typing.Callable] = { + "CREATE TABLE": SQLFunctions.create_table, + "SELECT": SQLFunctions.select, + "INSERT INTO": SQLFunctions.insert_into, +} + + +def run_sql(input_sql: list[str]) -> list[str]: + output = [] + state = SQLState(state={}) + + # remove comments + input_sql = [line.strip() for line in input_sql if not line.startswith("--")] + + # re-split on semi-colons + input_sql = " ".join(input_sql).split(";") + + # iterate over each line of sql + for line in input_sql: + words = line.split(" ") + for i in reversed(range(len(words) + 1)): + key = " ".join(words[:i]).strip() + if func := sql_function_map.get(key): + output, state = func(state, *[word for word in words if word]) + break + + return [json.dumps(output)] diff --git a/src/python/sql_test.py b/src/python/sql_test.py index 7f41e13..0e6df4f 100644 --- a/src/python/sql_test.py +++ b/src/python/sql_test.py @@ -9,33 +9,121 @@ ######################## +import dataclasses import json +import typing + + +@dataclasses.dataclass(frozen=True) +class SQLState: + state: dict + + def read_table_meta(self, table_name: str) -> dict: + return self.state.get(table_name, {}).get("metadata", {}) + + def read_table_rows(self, table_name: str) -> list[dict]: + return self.state.get(table_name, {}).get("rows", []) + + def read_information_schema(self) -> list[dict]: + return [data["metadata"] for data in self.state.values()] + + def write_table_meta(self, table_name: str, data: dict): + state = self.state + table = state.get(table_name, {}) + metadata = table.get("metadata", {}) + metadata.update(data) + table["metadata"] = metadata + state[table_name] = table + return self.__class__(state) + + def write_table_rows(self, table_name: str, data: dict): + state = self.state + table = state.get(table_name, {}) + rows = table.get("rows", []) + rows.append(data) + table["rows"] = rows + state[table_name] = table + return self.__class__(state) + + +class SQLType: + @staticmethod + def varchar(data) -> str: + data_str = str(data).strip() + if data_str.startswith("'") or data_str.startswith('"'): + data_str = data_str[1:] + if data_str.endswith("'") or data_str.endswith('"'): + data_str = data_str[:-1] + return data_str + + @staticmethod + def int(data) -> int: + return int(data.strip()) + + +sql_type_map = { + "VARCHAR": SQLType.varchar, + "INT": SQLType.int, +} + + +class SQLFunctions: + @staticmethod + def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + table_name = args[2] + # get columns + columns = {} + columns_str = " ".join(args[3:]).replace("(", "").replace(")", "").strip() + if columns_str: + # fmt: off + columns = { + column.strip().split(" ")[0]: column.strip().split(" ")[1] + for column in columns_str.split(",") + } + # fmt: on -class SQL: - data: dict = {} - - def __init__(self) -> None: - self.data = {} - - def information_schema_tables(self) -> list[dict]: - return [data["metadata"] for data in self.data.values()] - - def create_table(self, *args, table_schema="public") -> dict: - table_name = args[2] - if not self.data.get(table_name): - self.data[table_name] = { - "metadata": { + if not state.read_table_meta(table_name): + state = state.write_table_meta( + table_name, + { "table_name": table_name, "table_schema": table_schema, + "colums": columns, }, - } - return {} + ) + return (output, state) + + @staticmethod + def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + table_name = args[2] + + values_index = None + for i, arg in enumerate(args): + if arg == "VALUES": + values_index = i + if values_index is None: + raise ValueError("VALUES not found") + + keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",") + keys = [key.strip() for key in keys] + values = " ".join(args[values_index + 1 :]).replace("(", "").replace(")", "").split(",") + values = [value.strip() for value in values] + key_value_map = dict(zip(keys, values)) - create_table.sql = "CREATE TABLE" + data = {} + if metadata := state.read_table_meta(table_name): + for key, value in key_value_map.items(): + data[key] = sql_type_map[metadata["colums"][key]](value) + state = state.write_table_rows(table_name, data) - def select(self, *args) -> dict: - output = {} + return (output, state) + + @staticmethod + def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: + output: list[dict] = [] from_index = None where_index = None @@ -44,51 +132,61 @@ def select(self, *args) -> dict: from_index = i if arg == "WHERE": where_index = i + if from_index is None: + raise ValueError("FROM not found") # get select keys by getting the slice of args before FROM select_keys = " ".join(args[1:from_index]).split(",") + select_keys = [key.strip() for key in select_keys] # get where keys by getting the slice of args after WHERE from_value = args[from_index + 1] - # consider "information_schema.tables" a special case until - # we figure out why its so different from the others + # `information_schema.tables` is a special case if from_value == "information_schema.tables": - target = self.information_schema_tables() + data = state.read_information_schema() + else: + data = state.read_table_rows(from_value) + + output = [] + for datum in data: + # fmt: off + output.append({ + key: datum.get(key) + for key in select_keys + }) + # fmt: on + + return (output, state) + - # fmt: off - output = { - key: [ - value for data in target - for key, value in data.items() - if key in select_keys - ] - for key in select_keys - } - # fmt: on +sql_function_map: dict[str, typing.Callable] = { + "CREATE TABLE": SQLFunctions.create_table, + "SELECT": SQLFunctions.select, + "INSERT INTO": SQLFunctions.insert_into, +} - return output - select.sql = "SELECT" +def run_sql(input_sql: list[str]) -> list[str]: + output = [] + state = SQLState(state={}) - sql_map = { - create_table.sql: create_table, - select.sql: select, - } + # remove comments + input_sql = [line.strip() for line in input_sql if not line.startswith("--")] - def run(self, input_sql: list[str]) -> list[str]: - output = {} + # re-split on semi-colons + input_sql = " ".join(input_sql).split(";") - for line in input_sql: - if not line.startswith("--"): - words = line.split(" ") - for i in reversed(range(len(words))): - key = " ".join(words[:i]) - if func := self.sql_map.get(key): - output = func(self, *words) - break + # iterate over each line of sql + for line in input_sql: + words = line.split(" ") + for i in reversed(range(len(words) + 1)): + key = " ".join(words[:i]).strip() + if func := sql_function_map.get(key): + output, state = func(state, *[word for word in words if word]) + break - return [json.dumps(output)] + return [json.dumps(output)] ###################### @@ -96,4 +194,4 @@ def run(self, input_sql: list[str]) -> list[str]: ###################### if __name__ == "__main__": - helpers.run(SQL().run) + helpers.run(run_sql) diff --git a/tasks.py b/tasks.py index f104be7..1012564 100644 --- a/tasks.py +++ b/tasks.py @@ -300,7 +300,7 @@ def run_tests(self, input_script): prepared_file_data = json.load(reader) with open(ctx.script_output_file_path, "r", encoding="utf-8") as reader: script_output_file_data = json.load(reader) - unittest.TestCase().assertDictEqual(prepared_file_data, script_output_file_data) + unittest.TestCase().assertListEqual(prepared_file_data, script_output_file_data) self.set_success_status(True) print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded") continue