Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL CREATE TABLE in python #89

Merged
merged 6 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
uses: actions/checkout@v3

- run: pip install invoke pyyaml
- run: invoke test ${{ matrix.language }} any any
- run: invoke test ${{ matrix.language }} any any --snippets
6 changes: 6 additions & 0 deletions data/sql_input_1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html
-- https://www.postgresql.org/docs/16/sql-createtable.html
-- https://www.postgresql.org/docs/16/sql-select.html
CREATE TABLE city ();
CREATE TABLE town ();
SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';
3 changes: 3 additions & 0 deletions data/sql_output_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"table_name": ["city", "town"]
}
80 changes: 77 additions & 3 deletions snippets/python/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,81 @@
import json


def run_sql(input_sql: list[str]) -> list[str]:
    """Return a canned table listing as a one-element list of JSON text.

    The statements in ``input_sql`` are not interpreted; the result is
    always the same hard-coded ``table_name`` listing.
    """
    return [json.dumps({"table_name": ["city"]})]
class SQL:
    """Minimal in-memory SQL interpreter used by the snippet tests.

    Supported statements:

    * ``CREATE TABLE <name> ...``
    * ``SELECT <col>[, <col>...] FROM information_schema.tables
      [WHERE <column> = '<value>']``

    Statements are tokenised on whitespace and dispatched through
    ``sql_map``, which maps a leading keyword phrase to its handler.
    """

    # table name -> {"metadata": {...}}; assigned per instance in __init__
    data: dict

    def __init__(self) -> None:
        self.data = {}

    def information_schema_tables(self) -> list[dict]:
        """Return the metadata record of every table created so far."""
        return [table["metadata"] for table in self.data.values()]

    def create_table(self, *args, table_schema="public") -> dict:
        """Register a table; creating an existing table is a no-op.

        Args:
            *args: the statement tokens, e.g. ``("CREATE", "TABLE", "city", "();")``.
            table_schema: schema recorded in the table's metadata.

        Returns:
            An empty dict (``CREATE TABLE`` produces no result set).

        Raises:
            ValueError: if no table name can be read from the tokens.
        """
        if len(args) < 3:
            raise ValueError("CREATE TABLE statement is missing a table name")

        # tolerate "city();" / "city;" where the name is glued to what follows
        table_name = args[2].split("(", 1)[0].rstrip(";")
        if not table_name:
            raise ValueError("CREATE TABLE statement is missing a table name")

        if not self.data.get(table_name):
            self.data[table_name] = {
                "metadata": {
                    "table_name": table_name,
                    "table_schema": table_schema,
                },
            }
        return {}

    create_table.sql = "CREATE TABLE"

    @staticmethod
    def _parse_equality(condition: str) -> "tuple[str, str] | None":
        """Parse ``column = 'value'``; return ``None`` for any other form."""
        column, separator, literal = condition.strip().rstrip(";").partition("=")
        column = column.strip()
        literal = literal.strip()
        if not separator or not column.isidentifier() or len(literal) < 2:
            return None
        quote = literal[0]
        if quote not in "'\"" or literal[-1] != quote:
            return None
        value = literal[1:-1]
        # a quote inside the literal means a compound condition (AND/OR ...)
        if quote in value:
            return None
        return column, value

    def select(self, *args) -> dict:
        """Evaluate a ``SELECT`` statement.

        Args:
            *args: the statement tokens, starting with ``"SELECT"``.

        Returns:
            A dict mapping each selected column to the list of its values,
            one per matching row. Columns that no row has map to ``[]``.

        Raises:
            ValueError: if the FROM clause is missing or names a table
                other than ``information_schema.tables``.
        """
        from_index = None
        where_index = None
        for i, arg in enumerate(args):
            if arg == "FROM":
                from_index = i
            if arg == "WHERE":
                where_index = i

        if from_index is None or from_index + 1 >= len(args):
            raise ValueError("SELECT statement is missing a FROM clause")

        # selected columns are the comma-separated tokens between SELECT and FROM
        joined_keys = " ".join(args[1:from_index])
        select_keys = [key.strip() for key in joined_keys.split(",") if key.strip()]

        # the table name may carry the statement terminator when there is no WHERE
        from_value = args[from_index + 1].rstrip(";")

        # consider "information_schema.tables" a special case until
        # we figure out why it's so different from the others
        if from_value != "information_schema.tables":
            raise ValueError(f"SELECT from {from_value!r} is not supported")
        rows = self.information_schema_tables()

        if where_index is not None and where_index > from_index:
            equality = self._parse_equality(" ".join(args[where_index + 1:]))
            # Only `column = 'value'` is understood. Any other condition
            # leaves the rows unfiltered, as before WHERE was evaluated.
            if equality is not None:
                column, expected = equality
                rows = [row for row in rows if row.get(column) == expected]

        return {
            key: [row[key] for row in rows if key in row]
            for key in select_keys
        }

    select.sql = "SELECT"

    sql_map = {
        create_table.sql: create_table,
        select.sql: select,
    }

    def run(self, input_sql: list[str]) -> list[str]:
        """Execute each statement line and return the last result as JSON.

        Blank lines and ``--`` comment lines are skipped, as are lines whose
        leading words match no entry in ``sql_map``.

        Args:
            input_sql: one SQL statement (or comment) per element.

        Returns:
            A one-element list holding the JSON text of the result of the
            last dispatched statement (``{}`` if none was dispatched).
        """
        output = {}
        # no keyword phrase is longer than this, so longer prefixes cannot match
        max_key_words = max(len(key.split()) for key in self.sql_map)

        for line in input_sql:
            statement = line.strip()
            if not statement or statement.startswith("--"):
                continue

            # split() also copes with newlines and repeated spaces
            words = statement.split()
            # longest keyword phrase first; the phrase must be followed by
            # at least one more word
            for size in range(min(len(words) - 1, max_key_words), 0, -1):
                if func := self.sql_map.get(" ".join(words[:size])):
                    output = func(self, *words)
                    break

        return [json.dumps(output)]

82 changes: 78 additions & 4 deletions src/python/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,88 @@
import json


def run_sql(input_sql: list[str]) -> list[str]:
    """Return a canned table listing as a one-element list of JSON text.

    The statements in ``input_sql`` are not interpreted; the result is
    always the same hard-coded ``table_name`` listing.
    """
    return [json.dumps({"table_name": ["city"]})]
class SQL:
    """Minimal in-memory SQL interpreter used by the snippet tests.

    Supported statements:

    * ``CREATE TABLE <name> ...``
    * ``SELECT <col>[, <col>...] FROM information_schema.tables
      [WHERE <column> = '<value>']``

    Statements are tokenised on whitespace and dispatched through
    ``sql_map``, which maps a leading keyword phrase to its handler.
    """

    # table name -> {"metadata": {...}}; assigned per instance in __init__
    data: dict

    def __init__(self) -> None:
        self.data = {}

    def information_schema_tables(self) -> list[dict]:
        """Return the metadata record of every table created so far."""
        return [table["metadata"] for table in self.data.values()]

    def create_table(self, *args, table_schema="public") -> dict:
        """Register a table; creating an existing table is a no-op.

        Args:
            *args: the statement tokens, e.g. ``("CREATE", "TABLE", "city", "();")``.
            table_schema: schema recorded in the table's metadata.

        Returns:
            An empty dict (``CREATE TABLE`` produces no result set).

        Raises:
            ValueError: if no table name can be read from the tokens.
        """
        if len(args) < 3:
            raise ValueError("CREATE TABLE statement is missing a table name")

        # tolerate "city();" / "city;" where the name is glued to what follows
        table_name = args[2].split("(", 1)[0].rstrip(";")
        if not table_name:
            raise ValueError("CREATE TABLE statement is missing a table name")

        if not self.data.get(table_name):
            self.data[table_name] = {
                "metadata": {
                    "table_name": table_name,
                    "table_schema": table_schema,
                },
            }
        return {}

    create_table.sql = "CREATE TABLE"

    @staticmethod
    def _parse_equality(condition: str) -> "tuple[str, str] | None":
        """Parse ``column = 'value'``; return ``None`` for any other form."""
        column, separator, literal = condition.strip().rstrip(";").partition("=")
        column = column.strip()
        literal = literal.strip()
        if not separator or not column.isidentifier() or len(literal) < 2:
            return None
        quote = literal[0]
        if quote not in "'\"" or literal[-1] != quote:
            return None
        value = literal[1:-1]
        # a quote inside the literal means a compound condition (AND/OR ...)
        if quote in value:
            return None
        return column, value

    def select(self, *args) -> dict:
        """Evaluate a ``SELECT`` statement.

        Args:
            *args: the statement tokens, starting with ``"SELECT"``.

        Returns:
            A dict mapping each selected column to the list of its values,
            one per matching row. Columns that no row has map to ``[]``.

        Raises:
            ValueError: if the FROM clause is missing or names a table
                other than ``information_schema.tables``.
        """
        from_index = None
        where_index = None
        for i, arg in enumerate(args):
            if arg == "FROM":
                from_index = i
            if arg == "WHERE":
                where_index = i

        if from_index is None or from_index + 1 >= len(args):
            raise ValueError("SELECT statement is missing a FROM clause")

        # selected columns are the comma-separated tokens between SELECT and FROM
        joined_keys = " ".join(args[1:from_index])
        select_keys = [key.strip() for key in joined_keys.split(",") if key.strip()]

        # the table name may carry the statement terminator when there is no WHERE
        from_value = args[from_index + 1].rstrip(";")

        # consider "information_schema.tables" a special case until
        # we figure out why it's so different from the others
        if from_value != "information_schema.tables":
            raise ValueError(f"SELECT from {from_value!r} is not supported")
        rows = self.information_schema_tables()

        if where_index is not None and where_index > from_index:
            equality = self._parse_equality(" ".join(args[where_index + 1:]))
            # Only `column = 'value'` is understood. Any other condition
            # leaves the rows unfiltered, as before WHERE was evaluated.
            if equality is not None:
                column, expected = equality
                rows = [row for row in rows if row.get(column) == expected]

        return {
            key: [row[key] for row in rows if key in row]
            for key in select_keys
        }

    select.sql = "SELECT"

    sql_map = {
        create_table.sql: create_table,
        select.sql: select,
    }

    def run(self, input_sql: list[str]) -> list[str]:
        """Execute each statement line and return the last result as JSON.

        Blank lines and ``--`` comment lines are skipped, as are lines whose
        leading words match no entry in ``sql_map``.

        Args:
            input_sql: one SQL statement (or comment) per element.

        Returns:
            A one-element list holding the JSON text of the result of the
            last dispatched statement (``{}`` if none was dispatched).
        """
        output = {}
        # no keyword phrase is longer than this, so longer prefixes cannot match
        max_key_words = max(len(key.split()) for key in self.sql_map)

        for line in input_sql:
            statement = line.strip()
            if not statement or statement.startswith("--"):
                continue

            # split() also copes with newlines and repeated spaces
            words = statement.split()
            # longest keyword phrase first; the phrase must be followed by
            # at least one more word
            for size in range(min(len(words) - 1, max_key_words), 0, -1):
                if func := self.sql_map.get(" ".join(words[:size])):
                    output = func(self, *words)
                    break

        return [json.dumps(output)]


######################
# business logic end #
######################

# Script entry point: hands a runner callable to helpers.run.
# NOTE(review): this span is a rendered diff without +/- markers; the
# `run_sql` call looks like the removed line and the `SQL().run` call like
# its replacement. Confirm against the real file.
if __name__ == "__main__":
helpers.run(run_sql)
helpers.run(SQL().run)
18 changes: 9 additions & 9 deletions tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# builtin packages
import unittest
import filecmp
import glob
import os
Expand Down Expand Up @@ -163,6 +164,8 @@ def generate(self, language, config, script_path, input_file_path):
docker_run_test_list = [
"docker",
"run",
"--rm",
f"--name={language}",
f"--volume={self.base_directory}:/workdir",
"-w=/workdir",
]
Expand Down Expand Up @@ -297,13 +300,9 @@ def run_tests(self, input_script):
prepared_file_data = json.load(reader)
with open(ctx.script_output_file_path, "r", encoding="utf-8") as reader:
script_output_file_data = json.load(reader)
if prepared_file_data == script_output_file_data:
self.set_success_status(True)
print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded")
else:
self.set_success_status(False)
print(f"\t🔴 {ctx.script_relative_path} on {ctx.input_file_path} failed, reason:")
print(f"\t\t output file {ctx.script_output_file_name} has does not match the prepared file")
unittest.TestCase().assertDictEqual(prepared_file_data, script_output_file_data)
self.set_success_status(True)
print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded")
continue

# check if the output file matches the prepared file
Expand Down Expand Up @@ -392,12 +391,13 @@ def show_results(self):


# invoke task: build a TestRunner, run the tests for `input_script`, optionally
# generate snippets, then show the results.
# NOTE(review): this span is a rendered diff without +/- markers; the two
# `def test` lines look like the old and new signature, and the unconditional
# `runner.generate_snippets(...)` call looks like the line that the
# `if snippets:` guard replaces. Confirm against the real file.
@invoke.task
def test(ctx: invoke.Context, language, input_script, input_data_index):
def test(ctx: invoke.Context, language, input_script, input_data_index, snippets=False):
# language is the programming language to run scripts in
# input_script is the name of a script you want to run
# input_data_index is passed through to TestRunner unchanged
# snippets (default False): when truthy, runner.generate_snippets is also called
runner = TestRunner(ctx, language, input_data_index)
runner.run_tests(input_script)
runner.generate_snippets(input_script)
if snippets:
runner.generate_snippets(input_script)
runner.show_results()


Expand Down