Skip to content

Commit

Permalink
Merge pull request #27 from NREL/sp/aws_cohere
Browse files Browse the repository at this point in the history
Alter get_embedding so that it is compatible with cohere models
  • Loading branch information
spodgorny9 committed Aug 9, 2024
2 parents d688041 + 9d7cfa0 commit 265ceb2
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ class EnergyWizardPostgres(EnergyWizardBase):
This class is designed as follows:
Vector database: PostgreSQL database accessed using psycopg2.
Query Embedding: AWS titan using boto3
Query Embedding: AWS using boto3
LLM Application: GPT4 via Azure deployment
"""
EMBEDDING_MODEL = 'amazon.titan-embed-text-v1'
Expand Down Expand Up @@ -480,7 +480,7 @@ def __init__(self, db_host, db_port, db_name,

def get_embedding(self, text):
"""Get the 1D array (list) embedding of a text string
as generated by AWS Titan.
as generated by the specified AWS model.
Parameters
----------
Expand All @@ -492,10 +492,16 @@ def get_embedding(self, text):
embedding : list
List of floats that represents the numerical embedding of the text
"""
model_id = self.EMBEDDING_MODEL

body = json.dumps({"inputText": text, })
if 'cohere' in model_id:
input_type = "search_query"

body = json.dumps({"texts": [text],
"input_type": input_type})
else:
body = json.dumps({"inputText": text, })

model_id = self.EMBEDDING_MODEL
accept = 'application/json'
content_type = 'application/json'

Expand All @@ -507,7 +513,11 @@ def get_embedding(self, text):
)

response_body = json.loads(response['body'].read())
embedding = response_body.get('embedding')

if 'cohere' in model_id:
embedding = response_body.get('embeddings')[0]
else:
embedding = response_body.get('embedding')

return embedding

Expand Down

0 comments on commit 265ceb2

Please sign in to comment.