diff --git a/.github/workflows/config.yml b/.github/workflows/config.yml index aa10317..887c634 100644 --- a/.github/workflows/config.yml +++ b/.github/workflows/config.yml @@ -28,4 +28,4 @@ jobs: uses: actions/checkout@v3 - run: pip install invoke pyyaml - - run: invoke test ${{ matrix.language }} any any + - run: invoke test ${{ matrix.language }} any any --snippets diff --git a/data/sql_input_1.sql b/data/sql_input_1.sql new file mode 100644 index 0000000..bb65967 --- /dev/null +++ b/data/sql_input_1.sql @@ -0,0 +1,6 @@ +-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html +-- https://www.postgresql.org/docs/16/sql-createtable.html +-- https://www.postgresql.org/docs/16/sql-select.html +CREATE TABLE city (); +CREATE TABLE town (); +SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'; diff --git a/data/sql_output_1.json b/data/sql_output_1.json new file mode 100644 index 0000000..86970c5 --- /dev/null +++ b/data/sql_output_1.json @@ -0,0 +1,3 @@ +{ + "table_name": ["city", "town"] +} diff --git a/snippets/python/sql_test.py b/snippets/python/sql_test.py index 1a07517..b9d290e 100644 --- a/snippets/python/sql_test.py +++ b/snippets/python/sql_test.py @@ -2,7 +2,81 @@ import json -def run_sql(input_sql: list[str]) -> list[str]: - output = {"table_name": ["city"]} - return [json.dumps(output)] +class SQL: + data: dict = {} + + def __init__(self) -> None: + self.data = {} + + def information_schema_tables(self) -> list[dict]: + return [data["metadata"] for data in self.data.values()] + + def create_table(self, *args, table_schema="public") -> dict: + table_name = args[2] + if not self.data.get(table_name): + self.data[table_name] = { + "metadata": { + "table_name": table_name, + "table_schema": table_schema, + }, + } + return {} + + create_table.sql = "CREATE TABLE" + + def select(self, *args) -> dict: + output = {} + + from_index = None + where_index = None + for i, arg in enumerate(args): + if arg == "FROM": + from_index = i + if arg == "WHERE": + where_index = i + + # get select keys by getting the slice of args before FROM + select_keys = " ".join(args[1:from_index]).split(",") + + # get where keys by getting the slice of args after WHERE + from_value = args[from_index + 1] + + # consider "information_schema.tables" a special case until + # we figure out why its so different from the others + if from_value == "information_schema.tables": + target = self.information_schema_tables() + + # fmt: off + output = { + key: [ + value for data in target + for key, value in data.items() + if key in select_keys + ] + for key in select_keys + } + # fmt: on + + return output + + select.sql = "SELECT" + + sql_map = { + create_table.sql: create_table, + select.sql: select, + } + + def run(self, input_sql: list[str]) -> list[str]: + output = {} + + for line in input_sql: + if not line.startswith("--"): + words = line.split(" ") + for i in reversed(range(len(words))): + key = " ".join(words[:i]) + if func := self.sql_map.get(key): + output = func(self, *words) + break + + return [json.dumps(output)] diff --git a/src/python/sql_test.py b/src/python/sql_test.py index 3ac0af7..7f41e13 100644 --- a/src/python/sql_test.py +++ b/src/python/sql_test.py @@ -12,9 +12,83 @@ import json -def run_sql(input_sql: list[str]) -> list[str]: - output = {"table_name": ["city"]} - return [json.dumps(output)] +class SQL: + data: dict = {} + + def __init__(self) -> None: + self.data = {} + + def information_schema_tables(self) -> list[dict]: + return [data["metadata"] for data in self.data.values()] + + def create_table(self, *args, table_schema="public") -> dict: + table_name = args[2] + if not self.data.get(table_name): + self.data[table_name] = { + "metadata": { + "table_name": table_name, + "table_schema": table_schema, + }, + } + return {} + + create_table.sql = "CREATE TABLE" + + def select(self, *args) -> dict: + output = {} + + from_index = None + where_index = None + for i, arg in enumerate(args): + if arg == "FROM": + from_index = i + if arg == "WHERE": + where_index = i + + # get select keys by getting the slice of args before FROM + select_keys = " ".join(args[1:from_index]).split(",") + + # get where keys by getting the slice of args after WHERE + from_value = args[from_index + 1] + + # consider "information_schema.tables" a special case until + # we figure out why its so different from the others + if from_value == "information_schema.tables": + target = self.information_schema_tables() + + # fmt: off + output = { + key: [ + value for data in target + for key, value in data.items() + if key in select_keys + ] + for key in select_keys + } + # fmt: on + + return output + + select.sql = "SELECT" + + sql_map = { + create_table.sql: create_table, + select.sql: select, + } + + def run(self, input_sql: list[str]) -> list[str]: + output = {} + + for line in input_sql: + if not line.startswith("--"): + words = line.split(" ") + for i in reversed(range(len(words))): + key = " ".join(words[:i]) + if func := self.sql_map.get(key): + output = func(self, *words) + break + + return [json.dumps(output)] ###################### @@ -22,4 +96,4 @@ def run_sql(input_sql: list[str]) -> list[str]: ###################### if __name__ == "__main__": - helpers.run(run_sql) + helpers.run(SQL().run) diff --git a/tasks.py b/tasks.py index db3215d..f104be7 100644 --- a/tasks.py +++ b/tasks.py @@ -1,4 +1,5 @@ # builtin packages +import unittest import filecmp import glob import os @@ -163,6 +164,8 @@ def generate(self, language, config, script_path, input_file_path): docker_run_test_list = [ "docker", "run", + "--rm", + f"--name={language}", f"--volume={self.base_directory}:/workdir", "-w=/workdir", ] @@ -297,13 +300,9 @@ def run_tests(self, input_script): prepared_file_data = json.load(reader) with open(ctx.script_output_file_path, "r", encoding="utf-8") as reader: script_output_file_data = json.load(reader) - if prepared_file_data == script_output_file_data: - self.set_success_status(True) - print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded") - else: - self.set_success_status(False) - print(f"\t🔴 {ctx.script_relative_path} on {ctx.input_file_path} failed, reason:") - print(f"\t\t output file {ctx.script_output_file_name} has does not match the prepared file") + unittest.TestCase().assertDictEqual(prepared_file_data, script_output_file_data) + self.set_success_status(True) + print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded") continue # check if the output file matches the prepared file @@ -392,12 +391,13 @@ def show_results(self): @invoke.task -def test(ctx: invoke.Context, language, input_script, input_data_index): +def test(ctx: invoke.Context, language, input_script, input_data_index, snippets=False): # language is the programming language to run scripts in # input_script is the name of a script you want to run runner = TestRunner(ctx, language, input_data_index) runner.run_tests(input_script) - runner.generate_snippets(input_script) + if snippets: + runner.generate_snippets(input_script) runner.show_results()