Skip to content

Commit

Permalink
Dynamic column retrieval for metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed Jul 23, 2024
1 parent 0d3e7e5 commit 9d5c594
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
31 changes: 24 additions & 7 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ class EnergyWizardPostgres(EnergyWizardBase):
"""Optional mappings for weird azure names to tiktoken/openai names."""

def __init__(self, db_host, db_port, db_name,
db_schema, db_table, cursor=None, boto_client=None,
db_schema, db_table, meta_columns=None,
cursor=None, boto_client=None,
model=None, token_budget=3500):
"""
Parameters
Expand All @@ -421,6 +422,9 @@ def __init__(self, db_host, db_port, db_name,
db_table : str
Table to query in Postgres database. Necessary columns: id,
chunks, embedding, title, and url.
meta_columns : list
List of metadata columns to retrieve from database. Default
query returns title and url.
cursor : psycopg2.extensions.cursor
PostgreSQL database cursor used to execute queries.
boto_client: botocore.client.BedrockRuntime
Expand All @@ -437,6 +441,10 @@ def __init__(self, db_host, db_port, db_name,

self.db_schema = db_schema
self.db_table = db_table
if meta_columns is None:
self.meta_columns = ['title', 'url']
else:
self.meta_columns = meta_columns

if cursor is None:
db_user = os.getenv("EWIZ_DB_USER")
Expand Down Expand Up @@ -559,20 +567,29 @@ def make_ref_list(self, ids):
"""

placeholders = ', '.join(['%s'] * len(ids))
columns_str = ', '.join([f"{self.db_table}.{c}"
for c in self.meta_columns])

sql_query = (f"SELECT {self.db_table}.title, {self.db_table}.url "
sql_query = (f"SELECT {columns_str} "
f"FROM {self.db_schema}.{self.db_table} "
f"WHERE {self.db_table}.id IN (" + placeholders + ")")

self.cursor.execute(sql_query, ids)

refs = self.cursor.fetchall()

ref_strs = (f"{{\"parentTitle\": \"{item[0]}\", "
f"\"parentUrl\": \"{item[1]}\"}} " for item in refs)
ref_list = []
for item in refs:
ref_dict = {self.meta_columns[i]: item[i]
for i in range(len(self.meta_columns))}
ref_str = "{"
ref_str += ", ".join([f"\"{key}\": \"{value}\""
for key, value in ref_dict.items()])
ref_str += "}"

unique_values = set(ref_strs)
ref_list.append(ref_str)

ref_list = list(unique_values)
unique_values = set(ref_list)
unique_list = list(unique_values)

return ref_list
return unique_list
4 changes: 2 additions & 2 deletions tests/test_wizard_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ def test_postgres():
ref_list = wizard.make_ref_list(ids)

assert len(ref_list) > 0
assert 'parentTitle' in str(ref_list)
assert 'parentUrl' in str(ref_list)
assert 'title' in str(ref_list)
assert 'url' in str(ref_list)
assert 'research-hub.nrel.gov' in str(ref_list)

0 comments on commit 9d5c594

Please sign in to comment.