Skip to content

Commit

Permalink
snippets
Browse files Browse the repository at this point in the history
  • Loading branch information
coilysiren committed Oct 22, 2023
1 parent bdbe926 commit 36cebbb
Showing 1 changed file with 154 additions and 56 deletions.
210 changes: 154 additions & 56 deletions snippets/python/sql_test.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,119 @@

import dataclasses
import json
import typing


@dataclasses.dataclass(frozen=True)
class SQLState:
state: dict

def read_table_meta(self, table_name: str) -> dict:
return self.state.get(table_name, {}).get("metadata", {})

def read_table_rows(self, table_name: str) -> list[dict]:
return self.state.get(table_name, {}).get("rows", [])

def read_information_schema(self) -> list[dict]:
return [data["metadata"] for data in self.state.values()]

def write_table_meta(self, table_name: str, data: dict):
state = self.state
table = state.get(table_name, {})
metadata = table.get("metadata", {})
metadata.update(data)
table["metadata"] = metadata
state[table_name] = table
return self.__class__(state)

def write_table_rows(self, table_name: str, data: dict):
state = self.state
table = state.get(table_name, {})
rows = table.get("rows", [])
rows.append(data)
table["rows"] = rows
state[table_name] = table
return self.__class__(state)


class SQLType:
@staticmethod
def varchar(data) -> str:
data_str = str(data).strip()
if data_str.startswith("'") or data_str.startswith('"'):
data_str = data_str[1:]
if data_str.endswith("'") or data_str.endswith('"'):
data_str = data_str[:-1]
return data_str

@staticmethod
def int(data) -> int:
return int(data.strip())


sql_type_map = {
"VARCHAR": SQLType.varchar,
"INT": SQLType.int,
}


class SQLFunctions:
@staticmethod
def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]:
output: list[dict] = []
table_name = args[2]

# get columns
columns = {}
columns_str = " ".join(args[3:]).replace("(", "").replace(")", "").strip()
if columns_str:
# fmt: off
columns = {
column.strip().split(" ")[0]: column.strip().split(" ")[1]
for column in columns_str.split(",")
}
# fmt: on

class SQL:
data: dict = {}

def __init__(self) -> None:
self.data = {}

def information_schema_tables(self) -> list[dict]:
return [data["metadata"] for data in self.data.values()]

def create_table(self, *args, table_schema="public") -> dict:
table_name = args[2]
if not self.data.get(table_name):
self.data[table_name] = {
"metadata": {
if not state.read_table_meta(table_name):
state = state.write_table_meta(
table_name,
{
"table_name": table_name,
"table_schema": table_schema,
"colums": columns,
},
}
return {}
)
return (output, state)

@staticmethod
def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
output: list[dict] = []
table_name = args[2]

values_index = None
for i, arg in enumerate(args):
if arg == "VALUES":
values_index = i
if values_index is None:
raise ValueError("VALUES not found")

keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",")
keys = [key.strip() for key in keys]
values = " ".join(args[values_index + 1 :]).replace("(", "").replace(")", "").split(",")
values = [value.strip() for value in values]
key_value_map = dict(zip(keys, values))

create_table.sql = "CREATE TABLE"
data = {}
if metadata := state.read_table_meta(table_name):
for key, value in key_value_map.items():
data[key] = sql_type_map[metadata["colums"][key]](value)
state = state.write_table_rows(table_name, data)

def select(self, *args) -> dict:
output = {}
return (output, state)

@staticmethod
def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
output: list[dict] = []

from_index = None
where_index = None
Expand All @@ -34,49 +122,59 @@ def select(self, *args) -> dict:
from_index = i
if arg == "WHERE":
where_index = i
if from_index is None:
raise ValueError("FROM not found")

# get select keys by getting the slice of args before FROM
select_keys = " ".join(args[1:from_index]).split(",")
select_keys = [key.strip() for key in select_keys]

# get where keys by getting the slice of args after WHERE
from_value = args[from_index + 1]

# consider "information_schema.tables" a special case until
# we figure out why its so different from the others
# `information_schema.tables` is a special case
if from_value == "information_schema.tables":
target = self.information_schema_tables()

# fmt: off
output = {
key: [
value for data in target
for key, value in data.items()
if key in select_keys
]
for key in select_keys
}
# fmt: on

return output

select.sql = "SELECT"

sql_map = {
create_table.sql: create_table,
select.sql: select,
}

def run(self, input_sql: list[str]) -> list[str]:
output = {}

for line in input_sql:
if not line.startswith("--"):
words = line.split(" ")
for i in reversed(range(len(words))):
key = " ".join(words[:i])
if func := self.sql_map.get(key):
output = func(self, *words)
break

return [json.dumps(output)]
data = state.read_information_schema()
else:
data = state.read_table_rows(from_value)

output = []
for datum in data:
# fmt: off
output.append({
key: datum.get(key)
for key in select_keys
})
# fmt: on

return (output, state)


sql_function_map: dict[str, typing.Callable] = {
"CREATE TABLE": SQLFunctions.create_table,
"SELECT": SQLFunctions.select,
"INSERT INTO": SQLFunctions.insert_into,
}


def run_sql(input_sql: list[str]) -> list[str]:
output = []
state = SQLState(state={})

# remove comments
input_sql = [line.strip() for line in input_sql if not line.startswith("--")]

# re-split on semi-colons
input_sql = " ".join(input_sql).split(";")

# iterate over each line of sql
for line in input_sql:
words = line.split(" ")
for i in reversed(range(len(words) + 1)):
key = " ".join(words[:i]).strip()
if func := sql_function_map.get(key):
output, state = func(state, *[word for word in words if word])
break

return [json.dumps(output)]

0 comments on commit 36cebbb

Please sign in to comment.