Skip to content

Commit

Permalink
_add_tag to include metadata in text chunks after db retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed Aug 20, 2024
1 parent 265ceb2 commit f9c0db5
Showing 1 changed file with 47 additions and 5 deletions.
52 changes: 47 additions & 5 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ class EnergyWizardPostgres(EnergyWizardBase):
def __init__(self, db_host, db_port, db_name,
db_schema, db_table, meta_columns=None,
cursor=None, boto_client=None,
model=None, token_budget=3500):
model=None, token_budget=3500,
tag=False):
"""
Parameters
----------
Expand Down Expand Up @@ -435,6 +436,9 @@ def __init__(self, db_host, db_port, db_name,
Number of tokens that can be embedded in the prompt. Note that the
default budget for GPT-3.5-Turbo is 4096, but you want to subtract
some tokens to account for the response budget.
tag: bool
Flag to add tag/metadata to text chunks before sending query to
GPT.
"""
boto3 = try_import('boto3')
psycopg2 = try_import('psycopg2')
Expand All @@ -461,6 +465,8 @@ def __init__(self, db_host, db_port, db_name,
else:
self.cursor = cursor

self.tag = tag

if boto_client is None:
access_key = os.getenv('AWS_ACCESS_KEY_ID')
secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
Expand Down Expand Up @@ -521,14 +527,42 @@ def get_embedding(self, text):

return embedding

def query_vector_db(self, query, limit=100):
@staticmethod
def _add_tag(meta):
"""Function to add tag/metadata to text strings before
sending query to GPT.
Parameters
----------
meta : tuple
Text values to include in tag (title, authors, year)
Returns
-------
tag : str
Text string containing provided metadata.
"""
title, authors, year = meta
if authors and year:
tag = (f"Title: {title}\n"
f"Authors: {authors}\n"
f"Publication Year: {year}\n\n"
)
else:
tag = f"Title: {title}\n\n"

return tag

def query_vector_db(self, query, probes=25, limit=100):
"""Returns a list of strings and relatednesses, sorted from most
related to least.
Parameters
----------
query : str
Question being asked of GPT
probes: int
Number of lists to search in vector database index.
limit : int
Number of top results to return.
Expand All @@ -545,17 +579,25 @@ def query_vector_db(self, query, limit=100):

query_embedding = self.get_embedding(query)

self.cursor.execute(f"SELECT {self.db_table}.id, "
self.cursor.execute(f"SET LOCAL ivfflat.probes = {probes};"
f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding "
"<=> %s::vector as score "
"<=> %s::vector as score, "
f"{self.db_table}.title, "
f"{self.db_table}.authors, "
f"{self.db_table}.year "
f"FROM {self.db_schema}.{self.db_table} "
"ORDER BY embedding <=> %s::vector LIMIT %s;",
(query_embedding, query_embedding, limit,), )

result = self.cursor.fetchall()

strings = [s[1] for s in result]
if self.tag:
strings = [self._add_tag(s[3:]) + s[1] for s in result]
else:
strings = [s[1] for s in result]

scores = [s[2] for s in result]
best = [s[0] for s in result]

Expand Down

0 comments on commit f9c0db5

Please sign in to comment.