From 0e5413100e3d89d879a349741726f6a1368172cc Mon Sep 17 00:00:00 2001 From: "kai [they]" Date: Sun, 22 Oct 2023 21:48:44 -0700 Subject: [PATCH] Tokenizer (#92) --- .gitignore | 1 + data/sql_input_2.sql | 3 +- data/sql_input_3.sql | 3 +- data/tokenizer_input_0.sql | 22 +++ data/tokenizer_input_1.sql | 1 + data/tokenizer_input_2.sql | 1 + data/tokenizer_input_3.sql | 2 + data/tokenizer_input_4.sql | 2 + data/tokenizer_input_5.sql | 6 + data/tokenizer_input_6.sql | 5 + data/tokenizer_output_0.json | 33 ++++ data/tokenizer_output_1.json | 6 + data/tokenizer_output_2.json | 6 + data/tokenizer_output_3.json | 11 ++ data/tokenizer_output_4.json | 6 + data/tokenizer_output_5.json | 6 + data/tokenizer_output_6.json | 9 ++ snippets/python/sql_script.py | 146 ++++++++++++++++++ snippets/python/tokenizer_script.py | 140 +++++++++++++++++ snippets/ruby/sql_test.rb | 2 +- src/python/{sql_test.py => sql_script.py} | 98 ++++-------- src/python/tokenizer_script.py | 175 ++++++++++++++++++++++ src/ruby/sql_test.rb | 2 +- tasks.py | 16 +- 24 files changed, 626 insertions(+), 76 deletions(-) create mode 100644 data/tokenizer_input_0.sql create mode 100644 data/tokenizer_input_1.sql create mode 100644 data/tokenizer_input_2.sql create mode 100644 data/tokenizer_input_3.sql create mode 100644 data/tokenizer_input_4.sql create mode 100644 data/tokenizer_input_5.sql create mode 100644 data/tokenizer_input_6.sql create mode 100644 data/tokenizer_output_0.json create mode 100644 data/tokenizer_output_1.json create mode 100644 data/tokenizer_output_2.json create mode 100644 data/tokenizer_output_3.json create mode 100644 data/tokenizer_output_4.json create mode 100644 data/tokenizer_output_5.json create mode 100644 data/tokenizer_output_6.json create mode 100644 snippets/python/sql_script.py create mode 100644 snippets/python/tokenizer_script.py rename src/python/{sql_test.py => sql_script.py} (58%) create mode 100644 src/python/tokenizer_script.py diff --git a/.gitignore b/.gitignore index 7bad37d..e150687 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ data/sort_by_* data/sorted_by_* src/rust/target/ __pycache__/ +**/.mypy_cache/ diff --git a/data/sql_input_2.sql b/data/sql_input_2.sql index 89b50fb..05b3e48 100644 --- a/data/sql_input_2.sql +++ b/data/sql_input_2.sql @@ -14,8 +14,9 @@ VALUES ('San Francisco', 852469, -8); INSERT INTO city (name, population, timezone) VALUES ('New York', 8405837, -5); -SELECT +SELECT ( name, population, timezone +) FROM city; diff --git a/data/sql_input_3.sql b/data/sql_input_3.sql index df39515..1fa2bba 100644 --- a/data/sql_input_3.sql +++ b/data/sql_input_3.sql @@ -14,8 +14,9 @@ VALUES ('San Francisco', -8); INSERT INTO city (name, population) VALUES ('New York', 8405837); -SELECT +SELECT ( name, population, timezone +) FROM city; diff --git a/data/tokenizer_input_0.sql b/data/tokenizer_input_0.sql new file mode 100644 index 0000000..c81010e --- /dev/null +++ b/data/tokenizer_input_0.sql @@ -0,0 +1,22 @@ +CREATE TABLE town (); + +CREATE TABLE city ( + name VARCHAR, + population INT, + timezone INT +); + +INSERT INTO city (name, population, timezone) +VALUES ('San Francisco', 852469, -8); + +INSERT INTO city (name, population) +VALUES ('New York', 8405837); + +SELECT ( + name, + population, + timezone +) +FROM city; + +SELECT name FROM city; diff --git a/data/tokenizer_input_1.sql b/data/tokenizer_input_1.sql new file mode 100644 index 0000000..a7853e8 --- /dev/null +++ b/data/tokenizer_input_1.sql @@ -0,0 +1 @@ +CREATE TABLE town (); diff --git 
a/data/tokenizer_input_2.sql b/data/tokenizer_input_2.sql new file mode 100644 index 0000000..89f31e9 --- /dev/null +++ b/data/tokenizer_input_2.sql @@ -0,0 +1 @@ +SELECT name FROM city; diff --git a/data/tokenizer_input_3.sql b/data/tokenizer_input_3.sql new file mode 100644 index 0000000..d7c2fac --- /dev/null +++ b/data/tokenizer_input_3.sql @@ -0,0 +1,2 @@ +INSERT INTO city (name, population, timezone) +VALUES ('San Francisco', 852469, -8); diff --git a/data/tokenizer_input_4.sql b/data/tokenizer_input_4.sql new file mode 100644 index 0000000..ea612e4 --- /dev/null +++ b/data/tokenizer_input_4.sql @@ -0,0 +1,2 @@ +INSERT INTO items (type) +VALUES ('"d"r"u"g"s"'); diff --git a/data/tokenizer_input_5.sql b/data/tokenizer_input_5.sql new file mode 100644 index 0000000..80c48c0 --- /dev/null +++ b/data/tokenizer_input_5.sql @@ -0,0 +1,6 @@ +SELECT ( + name, + population, + timezone +) +FROM city; diff --git a/data/tokenizer_input_6.sql b/data/tokenizer_input_6.sql new file mode 100644 index 0000000..4c2abf2 --- /dev/null +++ b/data/tokenizer_input_6.sql @@ -0,0 +1,5 @@ +CREATE TABLE city ( + name VARCHAR, + population INT, + timezone INT +); diff --git a/data/tokenizer_output_0.json b/data/tokenizer_output_0.json new file mode 100644 index 0000000..d345248 --- /dev/null +++ b/data/tokenizer_output_0.json @@ -0,0 +1,33 @@ +[ + { "worker": "CREATE TABLE", "args": ["town", []] }, + { + "worker": "CREATE TABLE", + "args": [ + "city", + ["name", "VARCHAR", "population", "INT", "timezone", "INT"] + ] + }, + { + "worker": "INSERT INTO", + "args": [ + "city", + ["name", "population", "timezone"], + "VALUES", + ["'San Francisco'", "852469", "-8"] + ] + }, + { + "worker": "INSERT INTO", + "args": [ + "city", + ["name", "population"], + "VALUES", + ["'New York'", "8405837"] + ] + }, + { + "worker": "SELECT", + "args": [["name", "population", "timezone"], "FROM", "city"] + }, + { "worker": "SELECT", "args": ["name", "FROM", "city"] } +] diff --git a/data/tokenizer_output_1.json b/data/tokenizer_output_1.json new file mode 100644 index 0000000..77a29e7 --- /dev/null +++ b/data/tokenizer_output_1.json @@ -0,0 +1,6 @@ +[ + { + "worker": "CREATE TABLE", + "args": ["town", []] + } +] diff --git a/data/tokenizer_output_2.json b/data/tokenizer_output_2.json new file mode 100644 index 0000000..ad00f33 --- /dev/null +++ b/data/tokenizer_output_2.json @@ -0,0 +1,6 @@ +[ + { + "worker": "SELECT", + "args": ["name", "FROM", "city"] + } +] diff --git a/data/tokenizer_output_3.json b/data/tokenizer_output_3.json new file mode 100644 index 0000000..c0805c3 --- /dev/null +++ b/data/tokenizer_output_3.json @@ -0,0 +1,11 @@ +[ + { + "worker": "INSERT INTO", + "args": [ + "city", + ["name", "population", "timezone"], + "VALUES", + ["'San Francisco'", "852469", "-8"] + ] + } +] diff --git a/data/tokenizer_output_4.json b/data/tokenizer_output_4.json new file mode 100644 index 0000000..531a7a1 --- /dev/null +++ b/data/tokenizer_output_4.json @@ -0,0 +1,6 @@ +[ + { + "worker": "INSERT INTO", + "args": ["items", ["type"], "VALUES", ["'\"d\"r\"u\"g\"s\"'"]] + } +] diff --git a/data/tokenizer_output_5.json b/data/tokenizer_output_5.json new file mode 100644 index 0000000..831da56 --- /dev/null +++ b/data/tokenizer_output_5.json @@ -0,0 +1,6 @@ +[ + { + "worker": "SELECT", + "args": [["name", "population", "timezone"], "FROM", "city"] + } +] diff --git a/data/tokenizer_output_6.json b/data/tokenizer_output_6.json new file mode 100644 index 0000000..9546031 --- /dev/null +++ b/data/tokenizer_output_6.json @@ -0,0 +1,9 @@ +[ + { 
+ "worker": "CREATE TABLE", + "args": [ + "city", + ["name", "VARCHAR", "population", "INT", "timezone", "INT"] + ] + } +] diff --git a/snippets/python/sql_script.py b/snippets/python/sql_script.py new file mode 100644 index 0000000..f1f2514 --- /dev/null +++ b/snippets/python/sql_script.py @@ -0,0 +1,146 @@ + +import dataclasses +import json +import re +import typing + +import tokenizer_script + + +@dataclasses.dataclass(frozen=True) +class SQLState: + state: dict + + def read_table_meta(self, table_name: str) -> dict: + return self.state.get(table_name, {}).get("metadata", {}) + + def read_table_rows(self, table_name: str) -> list[dict]: + return self.state.get(table_name, {}).get("rows", []) + + def read_information_schema(self) -> list[dict]: + return [data["metadata"] for data in self.state.values()] + + def write_table_meta(self, table_name: str, data: dict): + state = self.state + table = state.get(table_name, {}) + metadata = table.get("metadata", {}) + metadata.update(data) + table["metadata"] = metadata + state[table_name] = table + return self.__class__(state) + + def write_table_rows(self, table_name: str, data: dict): + state = self.state + table = state.get(table_name, {}) + rows = table.get("rows", []) + rows.append(data) + table["rows"] = rows + state[table_name] = table + return self.__class__(state) + + +class SQLType: + @staticmethod + def varchar(data) -> str: + data_str = str(data).strip() + data_str = re.sub(r'^["\']', "", data_str) # leading ' or " + data_str = re.sub(r'["\']$', "", data_str) # trailing ' or " + return data_str + + @staticmethod + def int(data) -> int: + return int(data.strip()) + + +class SQLFunctions: + @staticmethod + def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + table_name = args[0] + + # get columns + columns = {} + columns_str = args[1] + if columns_str: + # fmt: off + columns = { + columns_str[i]: columns_str[i + 1] + for i in range(0, len(columns_str), 2) + } + # fmt: on + + if not state.read_table_meta(table_name): + state = state.write_table_meta( + table_name, + { + "table_name": table_name, + "table_schema": table_schema, + "colums": columns, + }, + ) + return (output, state) + + @staticmethod + def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + table_name = args[0] + keys = args[1] + values = args[3] + key_value_map = dict(zip(keys, values)) + + sql_type_map = { + "VARCHAR": SQLType.varchar, + "INT": SQLType.int, + } + + data = {} + metadata = state.read_table_meta(table_name) + if metadata: + for key, value in key_value_map.items(): + data[key] = sql_type_map[metadata["colums"][key]](value) + state = state.write_table_rows(table_name, data) + + return (output, state) + + @staticmethod + def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: + output: list[dict] = [] + select_columns = args[0] if isinstance(args[0], list) else [args[0]] + from_value = args[2] + + # `information_schema.tables` is a special case + if from_value == "information_schema.tables": + data = state.read_information_schema() + else: + data = state.read_table_rows(from_value) + + output = [] + for datum in data: + # fmt: off + output.append({ + key: datum.get(key) + for key in select_columns + }) + # fmt: on + + return (output, state) + + +def run_sql(input_sql: list[str]) -> list[str]: + output = [] + state = SQLState(state={}) + sql_tokenizer = tokenizer_script.SQLTokenizer( + { + "CREATE TABLE": SQLFunctions.create_table, + 
"INSERT INTO": SQLFunctions.insert_into, + "SELECT": SQLFunctions.select, + } + ) + sql_token_list = sql_tokenizer.tokenize_sql(input_sql) + + # iterate over each line of sql + for sql_tokens in sql_token_list: + output, state = sql_tokens.worker_func(state, *sql_tokens.args) + + return [json.dumps(output)] + diff --git a/snippets/python/tokenizer_script.py b/snippets/python/tokenizer_script.py new file mode 100644 index 0000000..5ac941c --- /dev/null +++ b/snippets/python/tokenizer_script.py @@ -0,0 +1,140 @@ + +import dataclasses +import json +import os +import typing + + +DEBUG = bool(int(os.getenv("DEBUG", "0"))) + + +@dataclasses.dataclass(frozen=True) +class SQLTokens: + worker_str: str + worker_func: typing.Callable | None + args: list[typing.Any] + + +@dataclasses.dataclass(frozen=True) +class SQLTokenizer: + sql_function_map: dict[str, typing.Callable | None] + + def tokenize_sql(self, sql: list[str]) -> list[SQLTokens]: + # remove comments + sql = [line.strip() for line in sql if not line.startswith("--")] + # re-split on semi-colons, the semi-colons are the true line breaks in SQL + sql = " ".join(sql).split(";") + # remove empty lines + sql = [line.strip() for line in sql if line] + + # get worker strings + worker_strs = [] + worker_funcs = [] + args_strs = [] + for line in sql: + this_worker_str = None + # We sort the SQL function map by its key length, longest first. + # This is a low complexity way to ensure that we can match, for example, + # both `SET SESSION AUTHORIZATION` and `SET`. + # fmt: off + sql_function_map_ordered_keys = sorted([ + key + for key in self.sql_function_map.keys() + ], key=len, reverse=True) + # fmt: on + for key in sql_function_map_ordered_keys: + if line.startswith(key): + this_worker_str = key + worker_strs.append(key) + worker_funcs.append(self.sql_function_map[key]) + args_strs.append(line.replace(key, "").strip()) + break + if this_worker_str is None: + raise ValueError(f"Unknown worker function: {this_worker_str}") + + # tokenize args + args_list: list[list] = [] + for i, sentence in enumerate(args_strs): + args_list.append([]) + word_start: int | None = 0 + inside_list = False + string_start: tuple[int | None, str | None] = (None, None) + for k, letter in enumerate(sentence): + if (string_start[0] is None) and (letter in ["'", '"']): + if DEBUG: + print(f"at letter: {letter}, starting a string") + string_start = (k, letter) + elif (word_start is None) and (letter not in ["(", ")", ",", " "]): + if DEBUG: + print(f"at letter: {letter}, starting a word") + word_start = k + elif (letter == string_start[1]) and (sentence[k - 1] != "\\") and (inside_list): + if DEBUG: + print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}") + string = sentence[string_start[0] : k + 1] + args_list[i][-1].append(string) + string_start = (None, None) + word_start = None + elif (string_start[0] is not None) and (letter == string_start[1]) and (sentence[k - 1] != "\\"): + if DEBUG: + print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}") + string = sentence[string_start[0] : k + 1] + args_list[i].append(string) + string_start = (None, None) + word_start = None + elif (word_start is not None) and (letter in [")"]) and (inside_list) and (string_start[0] is None): + if DEBUG: + print( + f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}" + ) + word = sentence[word_start:k] + args_list[i][-1].append(word) + word_start = None + inside_list = False + elif ( + (word_start is not None) and 
(letter in [" ", ","]) and (inside_list) and (string_start[0] is None) + ): + if DEBUG: + print( + f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}" + ) + word = sentence[word_start:k] + args_list[i][-1].append(word) + word_start = None + elif (word_start is not None) and (letter in [" ", ")", ","]) and (string_start[0] is None): + if DEBUG: + print(f"at letter: {letter}, adding word: {sentence[word_start:k]}") + word = sentence[word_start:k] + args_list[i].append(word) + word_start = None + elif (word_start is not None) and (k == len(sentence) - 1): + if DEBUG: + print(f"at letter: {letter}, last word: {sentence[word_start:]}") + word = sentence[word_start:] + args_list[i].append(word) + word_start = None + elif letter == "(": + if DEBUG: + print(f"at letter: {letter}, starting a list") + inside_list = True + args_list[i].append([]) + word_start = None + elif (inside_list) and (letter in ")"): + if DEBUG: + print(f"at letter: {letter}, ending list") + inside_list = False + elif word_start is not None: + if DEBUG: + print(f"at letter: {letter}, inside of a word: {sentence[word_start:k]}") + else: + if DEBUG: + print(f"at letter: {letter}") + + return [ + SQLTokens( + worker_str=worker_str, + worker_func=worker_func, + args=args_list, + ) + for worker_str, worker_func, args_list in zip(worker_strs, worker_funcs, args_list) + ] diff --git a/snippets/ruby/sql_test.rb b/snippets/ruby/sql_test.rb index 2a389b6..7d244d4 100644 --- a/snippets/ruby/sql_test.rb +++ b/snippets/ruby/sql_test.rb @@ -99,7 +99,7 @@ def self.select(state, *args) raise 'FROM not found' if from_index.nil? - select_keys = args[1...from_index].join(' ').split(',').map(&:strip) + select_keys = args[1...from_index].join(' ').split(',').map {|s| s.gsub(/[()]/, '')}.map(&:strip) from_value = args[from_index + 1] data = if from_value == 'information_schema.tables' diff --git a/src/python/sql_test.py b/src/python/sql_script.py similarity index 58% rename from src/python/sql_test.py rename to src/python/sql_script.py index 0e6df4f..e9f7fb0 100644 --- a/src/python/sql_test.py +++ b/src/python/sql_script.py @@ -11,8 +11,11 @@ import dataclasses import json +import re import typing +import tokenizer_script + @dataclasses.dataclass(frozen=True) class SQLState: @@ -50,10 +53,8 @@ class SQLType: @staticmethod def varchar(data) -> str: data_str = str(data).strip() - if data_str.startswith("'") or data_str.startswith('"'): - data_str = data_str[1:] - if data_str.endswith("'") or data_str.endswith('"'): - data_str = data_str[:-1] + data_str = re.sub(r'^["\']', "", data_str) # leading ' or " + data_str = re.sub(r'["\']$', "", data_str) # trailing ' or " return data_str @staticmethod @@ -61,26 +62,20 @@ def int(data) -> int: return int(data.strip()) -sql_type_map = { - "VARCHAR": SQLType.varchar, - "INT": SQLType.int, -} - - class SQLFunctions: @staticmethod def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]: output: list[dict] = [] - table_name = args[2] + table_name = args[0] # get columns columns = {} - columns_str = " ".join(args[3:]).replace("(", "").replace(")", "").strip() + columns_str = args[1] if columns_str: # fmt: off columns = { - column.strip().split(" ")[0]: column.strip().split(" ")[1] - for column in columns_str.split(",") + columns_str[i]: columns_str[i + 1] + for i in range(0, len(columns_str), 2) } # fmt: on @@ -98,23 +93,19 @@ def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[ @staticmethod def 
insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: output: list[dict] = [] - table_name = args[2] - - values_index = None - for i, arg in enumerate(args): - if arg == "VALUES": - values_index = i - if values_index is None: - raise ValueError("VALUES not found") - - keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",") - keys = [key.strip() for key in keys] - values = " ".join(args[values_index + 1 :]).replace("(", "").replace(")", "").split(",") - values = [value.strip() for value in values] + table_name = args[0] + keys = args[1] + values = args[3] key_value_map = dict(zip(keys, values)) + sql_type_map = { + "VARCHAR": SQLType.varchar, + "INT": SQLType.int, + } + data = {} - if metadata := state.read_table_meta(table_name): + metadata = state.read_table_meta(table_name) + if metadata: for key, value in key_value_map.items(): data[key] = sql_type_map[metadata["colums"][key]](value) state = state.write_table_rows(table_name, data) @@ -124,23 +115,8 @@ def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: @staticmethod def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: output: list[dict] = [] - - from_index = None - where_index = None - for i, arg in enumerate(args): - if arg == "FROM": - from_index = i - if arg == "WHERE": - where_index = i - if from_index is None: - raise ValueError("FROM not found") - - # get select keys by getting the slice of args before FROM - select_keys = " ".join(args[1:from_index]).split(",") - select_keys = [key.strip() for key in select_keys] - - # get where keys by getting the slice of args after WHERE - from_value = args[from_index + 1] + select_columns = args[0] if isinstance(args[0], list) else [args[0]] + from_value = args[2] # `information_schema.tables` is a special case if from_value == "information_schema.tables": @@ -153,38 +129,28 @@ def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: # fmt: off output.append({ key: datum.get(key) - for key in select_keys + for key in select_columns }) # fmt: on return (output, state) -sql_function_map: dict[str, typing.Callable] = { - "CREATE TABLE": SQLFunctions.create_table, - "SELECT": SQLFunctions.select, - "INSERT INTO": SQLFunctions.insert_into, -} - - def run_sql(input_sql: list[str]) -> list[str]: output = [] state = SQLState(state={}) - - # remove comments - input_sql = [line.strip() for line in input_sql if not line.startswith("--")] - - # re-split on semi-colons - input_sql = " ".join(input_sql).split(";") + sql_tokenizer = tokenizer_script.SQLTokenizer( + { + "CREATE TABLE": SQLFunctions.create_table, + "INSERT INTO": SQLFunctions.insert_into, + "SELECT": SQLFunctions.select, + } + ) + sql_token_list = sql_tokenizer.tokenize_sql(input_sql) # iterate over each line of sql - for line in input_sql: - words = line.split(" ") - for i in reversed(range(len(words) + 1)): - key = " ".join(words[:i]).strip() - if func := sql_function_map.get(key): - output, state = func(state, *[word for word in words if word]) - break + for sql_tokens in sql_token_list: + output, state = sql_tokens.worker_func(state, *sql_tokens.args) return [json.dumps(output)] diff --git a/src/python/tokenizer_script.py b/src/python/tokenizer_script.py new file mode 100644 index 0000000..62e55a6 --- /dev/null +++ b/src/python/tokenizer_script.py @@ -0,0 +1,175 @@ +import helpers + +######################## +# business logic start # +######################## + + +import dataclasses +import json +import os +import typing + + +DEBUG = 
bool(int(os.getenv("DEBUG", "0"))) + + +@dataclasses.dataclass(frozen=True) +class SQLTokens: + worker_str: str + worker_func: typing.Callable | None + args: list[typing.Any] + + +@dataclasses.dataclass(frozen=True) +class SQLTokenizer: + sql_function_map: dict[str, typing.Callable | None] + + def tokenize_sql(self, sql: list[str]) -> list[SQLTokens]: + # remove comments + sql = [line.strip() for line in sql if not line.startswith("--")] + # re-split on semi-colons, the semi-colons are the true line breaks in SQL + sql = " ".join(sql).split(";") + # remove empty lines + sql = [line.strip() for line in sql if line] + + # get worker strings + worker_strs = [] + worker_funcs = [] + args_strs = [] + for line in sql: + this_worker_str = None + # We sort the SQL function map by its key length, longest first. + # This is a low complexity way to ensure that we can match, for example, + # both `SET SESSION AUTHORIZATION` and `SET`. + # fmt: off + sql_function_map_ordered_keys = sorted([ + key + for key in self.sql_function_map.keys() + ], key=len, reverse=True) + # fmt: on + for key in sql_function_map_ordered_keys: + if line.startswith(key): + this_worker_str = key + worker_strs.append(key) + worker_funcs.append(self.sql_function_map[key]) + args_strs.append(line.replace(key, "").strip()) + break + if this_worker_str is None: + raise ValueError(f"Unknown worker function: {this_worker_str}") + + # tokenize args + args_list: list[list] = [] + for i, sentence in enumerate(args_strs): + args_list.append([]) + word_start: int | None = 0 + inside_list = False + string_start: tuple[int | None, str | None] = (None, None) + for k, letter in enumerate(sentence): + if (string_start[0] is None) and (letter in ["'", '"']): + if DEBUG: + print(f"at letter: {letter}, starting a string") + string_start = (k, letter) + elif (word_start is None) and (letter not in ["(", ")", ",", " "]): + if DEBUG: + print(f"at letter: {letter}, starting a word") + word_start = k + elif (letter == string_start[1]) and (sentence[k - 1] != "\\") and (inside_list): + if DEBUG: + print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}") + string = sentence[string_start[0] : k + 1] + args_list[i][-1].append(string) + string_start = (None, None) + word_start = None + elif (string_start[0] is not None) and (letter == string_start[1]) and (sentence[k - 1] != "\\"): + if DEBUG: + print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}") + string = sentence[string_start[0] : k + 1] + args_list[i].append(string) + string_start = (None, None) + word_start = None + elif (word_start is not None) and (letter in [")"]) and (inside_list) and (string_start[0] is None): + if DEBUG: + print( + f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}" + ) + word = sentence[word_start:k] + args_list[i][-1].append(word) + word_start = None + inside_list = False + elif ( + (word_start is not None) and (letter in [" ", ","]) and (inside_list) and (string_start[0] is None) + ): + if DEBUG: + print( + f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}" + ) + word = sentence[word_start:k] + args_list[i][-1].append(word) + word_start = None + elif (word_start is not None) and (letter in [" ", ")", ","]) and (string_start[0] is None): + if DEBUG: + print(f"at letter: {letter}, adding word: {sentence[word_start:k]}") + word = sentence[word_start:k] + args_list[i].append(word) + word_start = None + elif (word_start is not None) and (k == len(sentence) - 
1): + if DEBUG: + print(f"at letter: {letter}, last word: {sentence[word_start:]}") + word = sentence[word_start:] + args_list[i].append(word) + word_start = None + elif letter == "(": + if DEBUG: + print(f"at letter: {letter}, starting a list") + inside_list = True + args_list[i].append([]) + word_start = None + elif (inside_list) and (letter in ")"): + if DEBUG: + print(f"at letter: {letter}, ending list") + inside_list = False + elif word_start is not None: + if DEBUG: + print(f"at letter: {letter}, inside of a word: {sentence[word_start:k]}") + else: + if DEBUG: + print(f"at letter: {letter}") + + return [ + SQLTokens( + worker_str=worker_str, + worker_func=worker_func, + args=args_list, + ) + for worker_str, worker_func, args_list in zip(worker_strs, worker_funcs, args_list) + ] + + ###################### + # business logic end # + ###################### + + def tokenize_sql_to_json(self, sql: list[str]) -> list[str]: + return [ + json.dumps( + [ + { + "worker": sql_tokens.worker_str, + "args": sql_tokens.args, + } + for sql_tokens in self.tokenize_sql(sql) + ] + ) + ] + + +if __name__ == "__main__": + helpers.run( + SQLTokenizer( + { + "CREATE TABLE": None, + "INSERT INTO": None, + "SELECT": None, + } + ).tokenize_sql_to_json + ) diff --git a/src/ruby/sql_test.rb b/src/ruby/sql_test.rb index 491bd38..cee0df3 100644 --- a/src/ruby/sql_test.rb +++ b/src/ruby/sql_test.rb @@ -107,7 +107,7 @@ def self.select(state, *args) raise 'FROM not found' if from_index.nil? - select_keys = args[1...from_index].join(' ').split(',').map(&:strip) + select_keys = args[1...from_index].join(' ').split(',').map {|s| s.gsub(/[()]/, '')}.map(&:strip) from_value = args[from_index + 1] data = if from_value == 'information_schema.tables' diff --git a/tasks.py b/tasks.py index bfdf839..d79cb92 100644 --- a/tasks.py +++ b/tasks.py @@ -32,7 +32,7 @@ def clean_string(inp): """remove unwanted characters from a string""" if inp: inp = inp.lower().strip() - for element in ["-", "_", "test", "sort"]: + for element in ["-", "_", "test", "script", "sort"]: inp = inp.replace(element, "") return inp @@ -62,12 +62,15 @@ def data(self): class TestRunnerContexts: base_directory = os.getcwd() data_folder_path = "./data" + debug = False ctxs: list[TestRunnerContext] = [] # We use these strings to mark the start and end of the important part of our scripts snippet_start_text = "business logic start" snippet_end_text = "business logic end" - def __init__(self, language, input_data_index) -> None: + def __init__(self, language, input_data_index, debug=False) -> None: + self.debug = debug + # get the language specific config with open(f"{self.base_directory}/config.yml", "r", encoding="utf-8") as obj: data = obj.read() @@ -176,6 +179,7 @@ def generate(self, language, config, script_path, input_file_path): # construct ending call args docker_run_test_list += [ + f"-e=DEBUG={1 if self.debug else 0}", f"-e=INPUT_PATH={input_file_path}", f"-e=OUTPUT_PATH={script_output_file_path}", config["dockerImage"], @@ -244,9 +248,9 @@ class TestRunner: invoke: invoke.Context ctxs: TestRunnerContexts - def __init__(self, _invoke, language, input_data_index) -> None: + def __init__(self, _invoke, language, input_data_index, debug=False) -> None: self.invoke = _invoke - self.ctxs = TestRunnerContexts(language, input_data_index) + self.ctxs = TestRunnerContexts(language, input_data_index, debug=debug) def run_tests(self, input_script): # run every test @@ -392,10 +396,10 @@ def show_results(self): @invoke.task -def test(ctx: invoke.Context, language, 
input_script, input_data_index, snippets=False):
+def test(ctx: invoke.Context, language, input_script, input_data_index, snippets=False, debug=False):
     # language is the programming language to run scripts in
     # input_script is the name of a script you want to run
-    runner = TestRunner(ctx, language, input_data_index)
+    runner = TestRunner(ctx, language, input_data_index, debug=debug)
     runner.run_tests(input_script)
     if snippets:
         runner.generate_snippets(input_script)
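
Reviewer note: a minimal, hypothetical usage sketch of the new tokenizer (not part of the patch). It assumes snippets/python is on sys.path so tokenizer_script imports without the repo's helpers module, tokenizes the statement from data/tokenizer_input_3.sql, and prints the structure recorded in data/tokenizer_output_3.json:

# Hypothetical sketch for review -- path assumption noted above.
import json

import tokenizer_script  # snippets/python/tokenizer_script.py

# Worker functions are optional for pure tokenization; the __main__ block
# in src/python/tokenizer_script.py likewise passes None for each key.
tokenizer = tokenizer_script.SQLTokenizer(
    {
        "CREATE TABLE": None,
        "INSERT INTO": None,
        "SELECT": None,
    }
)

sql_lines = [
    "INSERT INTO city (name, population, timezone)",
    "VALUES ('San Francisco', 852469, -8);",
]

for tokens in tokenizer.tokenize_sql(sql_lines):
    print(tokens.worker_str, json.dumps(tokens.args))
# Expected, per data/tokenizer_output_3.json:
# INSERT INTO ["city", ["name", "population", "timezone"], "VALUES", ["'San Francisco'", "852469", "-8"]]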
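
A companion sketch of the refactored run_sql in sql_script.py, which now delegates statement parsing to SQLTokenizer instead of splitting on whitespace. Again hypothetical: it assumes both snippet modules are importable, the statements are illustrative rather than taken from a data file, and the printed result is my reading of the new code path, not a captured run:

# Hypothetical end-to-end sketch for review.
import sql_script  # snippets/python/sql_script.py

statements = [
    "CREATE TABLE city (name VARCHAR, population INT);",
    "INSERT INTO city (name, population) VALUES ('New York', 8405837);",
    "SELECT (name, population) FROM city;",
]

# run_sql returns a single-element list holding the JSON-encoded rows
# produced by the final statement.
print(sql_script.run_sql(statements))
# Expected: ['[{"name": "New York", "population": 8405837}]']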