Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow passing variables that may exist in the loaded SQL files #7

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lineagex/LineageXNoConn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ def __init__(
dialect: str = "postgres",
target_schema: Optional[str] = "public",
search_path_schema: Optional[str] = "public",
variables:Optional[dict] = {},
) -> None:
self.output_dict = {}
self.parsed = 0
self.target_schema = target_schema
search_path_schema = [x.strip() for x in search_path_schema.split(",")]
search_path_schema.append(target_schema)
s2d = SqlToDict(path=sql, schema_list=search_path_schema, dialect=dialect)
s2d = SqlToDict(path=sql, schema_list=search_path_schema, dialect=dialect, variables=variables)
self.sql_files_dict = s2d.sql_files_dict
self.org_sql_files_dict = s2d.org_sql_files_dict
self.dialect = dialect
Expand Down
4 changes: 3 additions & 1 deletion lineagex/LineageXWithConn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
target_schema: Optional[str] = "public",
conn_string: Optional[str] = "",
search_path_schema: Optional[str] = "public",
variables:Optional[dict] = {},
) -> None:
self.parsed = 0
self.not_parsed = 0
Expand All @@ -35,6 +36,7 @@ def __init__(
self.output_dict = {}
self.conn = self._check_db_connection(conn_string)
self.conn.autocommit = True
self.variables = variables
self._run_table_lineage()

def _run_table_lineage(self) -> None:
Expand Down Expand Up @@ -112,7 +114,7 @@ def _run_table_lineage(self) -> None:
continue
# path or a list of SQL that at least one element contains
else:
self.sql_files_dict = SqlToDict(self.sql, self.schema_list).sql_files_dict
self.sql_files_dict = SqlToDict(self.sql, self.schema_list, variables=self.variables).sql_files_dict
for name, sql in self.sql_files_dict.items():
try:
if name not in self.finished_list:
Expand Down
14 changes: 10 additions & 4 deletions lineagex/SqlToDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import re
from typing import List, Optional, Union

from .utils import find_select, get_files, remove_comments
from .utils import find_select, get_files, remove_comments, load_sql_file, replace_variables

rem_regex = re.compile(r"[^a-zA-Z0-9_.]")


class SqlToDict:
def __init__(
self, path: Optional[Union[List, str]] = "", schema_list: Optional[List] = None, dialect: Optional[str] = "postgres"
self,
path: Optional[Union[List, str]] = "",
schema_list: Optional[List] = None,
dialect: Optional[str] = "postgres",
variables:Optional[dict] = {},
) -> None:
self.path = path
self.schema_list = schema_list
Expand All @@ -20,6 +24,7 @@ def __init__(
self.deletion_dict = {}
self.insertion_dict = {}
self.curr_name = ""
self.variables = variables
self._sql_to_dict()
pass

Expand All @@ -30,11 +35,12 @@ def _sql_to_dict(self) -> None:
"""
if isinstance(self.path, list):
for idx, val in enumerate(self.path):
self._preprocess_sql(new_sql=val, file=str(idx), org_sql=val)
code = replace_variables(val, self.variables)
self._preprocess_sql(new_sql=code, file=str(idx), org_sql=code)
else:
self.sql_files = get_files(path=self.path)
for f in self.sql_files:
org_sql = open(f, mode="r", encoding="latin-1").read()
org_sql = load_sql_file(f, self.variables)
new_sql = remove_comments(str1=org_sql)
org_sql_split = list(filter(None, new_sql.split(";")))
# pop DROP IF EXISTS
Expand Down
3 changes: 3 additions & 0 deletions lineagex/lineagex.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
conn_string: Optional[str] = None,
search_path_schema: Optional[str] = "",
dialect: str = "postgres",
variables:Optional[dict] = {},
) -> None:
validate_sql(sql)
target_schema, search_path_schema = validate_schema(
Expand All @@ -65,6 +66,7 @@ def __init__(
target_schema=target_schema,
conn_string=conn_string,
search_path_schema=search_path_schema,
variables=variables,
)
save_js_file()
self.output_dict = lx.output_dict
Expand All @@ -74,6 +76,7 @@ def __init__(
dialect=dialect,
target_schema=target_schema,
search_path_schema=search_path_schema,
variables=variables,
)
save_js_file()
self.output_dict = lx.output_dict
Expand Down
13 changes: 13 additions & 0 deletions lineagex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,5 +256,18 @@ def _produce_html(output_json: Optional[dict] = "") -> None:
)


def replace_variables(text: str, variables: Optional[dict] = None) -> str:
    """Substitute every occurrence of each variable key in *text*.

    Replacements are applied one key at a time, in the mapping's iteration
    order, using plain substring replacement (no regex, no escaping).

    Args:
        text: The SQL (or any other) text to process.
        variables: Mapping of placeholder -> replacement value. ``None`` or
            an empty mapping leaves *text* untouched. Values are converted
            with ``str()`` so numbers and similar can be passed directly.

    Returns:
        The text with all placeholders substituted.
    """
    # ``None`` is allowed by the annotation; treat it like "no variables".
    if not variables:
        return text

    replaced = text
    for key, val in variables.items():
        # str.replace("", x) would insert x between every character, which
        # silently corrupts the SQL; an empty placeholder is never meaningful.
        if key == "":
            continue
        replaced = replaced.replace(key, str(val))

    return replaced


def load_sql_file(file: str, variables: Optional[dict] = None) -> str:
    """Read a SQL file and substitute the given variables into its content.

    Args:
        file: Path of the file to read. It is decoded as latin-1.
        variables: Mapping of placeholder -> replacement passed on to
            ``replace_variables``. ``None`` means no substitutions.

    Returns:
        The file content after variable substitution.
    """
    # Use a context manager so the handle is closed even if read() fails.
    with open(file, mode="r", encoding="latin-1") as handle:
        sql = handle.read()
    # Always hand a real dict to replace_variables so None never reaches it.
    return replace_variables(sql, {} if variables is None else variables)


# Script entry point: running this module directly does nothing.
if __name__ == "__main__":
pass