From bfc7925594e8974899305467cbd72aaef4821da9 Mon Sep 17 00:00:00 2001 From: ltipton Date: Tue, 30 Jul 2024 13:45:39 -0400 Subject: [PATCH 1/2] feat: allow passing variables that may exist in the SQL files * Does a simple replace call on each variable --- lineagex/LineageXNoConn.py | 3 ++- lineagex/LineageXWithConn.py | 4 +++- lineagex/SqlToDict.py | 11 ++++++++--- lineagex/lineagex.py | 3 +++ lineagex/utils.py | 13 +++++++++++++ 5 files changed, 29 insertions(+), 5 deletions(-) diff --git a/lineagex/LineageXNoConn.py b/lineagex/LineageXNoConn.py index ffe499f..d9a2a77 100644 --- a/lineagex/LineageXNoConn.py +++ b/lineagex/LineageXNoConn.py @@ -35,13 +35,14 @@ def __init__( dialect: str = "postgres", target_schema: Optional[str] = "public", search_path_schema: Optional[str] = "public", + variables:Optional[dict] = {}, ) -> None: self.output_dict = {} self.parsed = 0 self.target_schema = target_schema search_path_schema = [x.strip() for x in search_path_schema.split(",")] search_path_schema.append(target_schema) - s2d = SqlToDict(path=sql, schema_list=search_path_schema, dialect=dialect) + s2d = SqlToDict(path=sql, schema_list=search_path_schema, dialect=dialect, variables=variables) self.sql_files_dict = s2d.sql_files_dict self.org_sql_files_dict = s2d.org_sql_files_dict self.dialect = dialect diff --git a/lineagex/LineageXWithConn.py b/lineagex/LineageXWithConn.py index 8aeb37b..a9940b5 100644 --- a/lineagex/LineageXWithConn.py +++ b/lineagex/LineageXWithConn.py @@ -18,6 +18,7 @@ def __init__( target_schema: Optional[str] = "public", conn_string: Optional[str] = "", search_path_schema: Optional[str] = "public", + variables:Optional[dict] = {}, ) -> None: self.parsed = 0 self.not_parsed = 0 @@ -35,6 +36,7 @@ def __init__( self.output_dict = {} self.conn = self._check_db_connection(conn_string) self.conn.autocommit = True + self.variables = variables self._run_table_lineage() def _run_table_lineage(self) -> None: @@ -112,7 +114,7 @@ def _run_table_lineage(self) -> None: continue # path or a list of SQL that at least one element contains else: - self.sql_files_dict = SqlToDict(self.sql, self.schema_list).sql_files_dict + self.sql_files_dict = SqlToDict(self.sql, self.schema_list, variables=self.variables).sql_files_dict for name, sql in self.sql_files_dict.items(): try: if name not in self.finished_list: diff --git a/lineagex/SqlToDict.py b/lineagex/SqlToDict.py index 4bb8322..09dcb2a 100644 --- a/lineagex/SqlToDict.py +++ b/lineagex/SqlToDict.py @@ -2,14 +2,18 @@ import re from typing import List, Optional, Union -from .utils import find_select, get_files, remove_comments +from .utils import find_select, get_files, remove_comments, load_sql_file rem_regex = re.compile(r"[^a-zA-Z0-9_.]") class SqlToDict: def __init__( - self, path: Optional[Union[List, str]] = "", schema_list: Optional[List] = None, dialect: Optional[str] = "postgres" + self, + path: Optional[Union[List, str]] = "", + schema_list: Optional[List] = None, + dialect: Optional[str] = "postgres", + variables:Optional[dict] = {}, ) -> None: self.path = path self.schema_list = schema_list @@ -20,6 +24,7 @@ def __init__( self.deletion_dict = {} self.insertion_dict = {} self.curr_name = "" + self.variables = variables self._sql_to_dict() pass @@ -34,7 +39,7 @@ def _sql_to_dict(self) -> None: else: self.sql_files = get_files(path=self.path) for f in self.sql_files: - org_sql = open(f, mode="r", encoding="latin-1").read() + org_sql = load_sql_file(f, self.variables) new_sql = remove_comments(str1=org_sql) org_sql_split = list(filter(None, new_sql.split(";"))) # pop DROP IF EXISTS diff --git a/lineagex/lineagex.py b/lineagex/lineagex.py index 33bd02a..38f0b9f 100644 --- a/lineagex/lineagex.py +++ b/lineagex/lineagex.py @@ -50,6 +50,7 @@ def __init__( conn_string: Optional[str] = None, search_path_schema: Optional[str] = "", dialect: str = "postgres", + variables:Optional[dict] = {}, ) -> None: validate_sql(sql) target_schema, search_path_schema = validate_schema( @@ -65,6 +66,7 @@ def __init__( target_schema=target_schema, conn_string=conn_string, search_path_schema=search_path_schema, + variables=variables, ) save_js_file() self.output_dict = lx.output_dict @@ -74,6 +76,7 @@ def __init__( dialect=dialect, target_schema=target_schema, search_path_schema=search_path_schema, + variables=variables, ) save_js_file() self.output_dict = lx.output_dict diff --git a/lineagex/utils.py b/lineagex/utils.py index 527f018..ecd28a6 100644 --- a/lineagex/utils.py +++ b/lineagex/utils.py @@ -256,5 +256,18 @@ def _produce_html(output_json: Optional[dict] = "") -> None: ) +def replace_variables(text:str, variables:Optional[dict] = {}): + replaced = text + for key, val in variables.items(): + replaced = replaced.replace(key, val) + + return replaced + + +def load_sql_file(file:str, variables:Optional[dict]={}): + sql = open(file, mode="r", encoding="latin-1").read() + return replace_variables(sql, variables) + + if __name__ == "__main__": pass From 457c2f63a2a3869f8c83ca40a76648b446fb2462 Mon Sep 17 00:00:00 2001 From: ltipton Date: Tue, 30 Jul 2024 16:59:07 -0400 Subject: [PATCH 2/2] fix: add replace variables to passed in sql code --- lineagex/SqlToDict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lineagex/SqlToDict.py b/lineagex/SqlToDict.py index 09dcb2a..a54b023 100644 --- a/lineagex/SqlToDict.py +++ b/lineagex/SqlToDict.py @@ -2,7 +2,7 @@ import re from typing import List, Optional, Union -from .utils import find_select, get_files, remove_comments, load_sql_file +from .utils import find_select, get_files, remove_comments, load_sql_file, replace_variables rem_regex = re.compile(r"[^a-zA-Z0-9_.]") @@ -35,7 +35,8 @@ def _sql_to_dict(self) -> None: """ if isinstance(self.path, list): for idx, val in enumerate(self.path): - self._preprocess_sql(new_sql=val, file=str(idx), org_sql=val) + code = replace_variables(val, self.variables) + self._preprocess_sql(new_sql=code, file=str(idx), org_sql=code) else: self.sql_files = get_files(path=self.path) for f in self.sql_files: