-
Notifications
You must be signed in to change notification settings - Fork 0
/
llm_eval_out_of_ten.py
103 lines (86 loc) · 3.88 KB
/
llm_eval_out_of_ten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Import required python packages
from astrapy.db import AstraDBCollection, AstraDB
from datasets import Dataset
from pprint import pprint
import pandas as pd
import requests
import json
from openai import OpenAI
import ftfy
import evaluation_prompts
import re
import os
from dotenv import load_dotenv
load_dotenv()
# AstraDB connection information
# Credentials are read from the environment (populated by .env above);
# token/endpoint identify the Astra database, collection_name the
# collection holding the instruction documents to evaluate.
token = os.getenv("token")
endpoint = os.getenv("endpoint")
collection_name = "test_instructions"
astra_db = AstraDB(token=token, api_endpoint=endpoint)
collection = AstraDBCollection(collection_name=collection_name, astra_db=astra_db)
# API key for OpenAI
OPENAI_API_KEY = os.getenv("openai_key")
# Client for OpenAI API
client = OpenAI(api_key = OPENAI_API_KEY)
# Pagination cursor for the Astra fetch loop below:
# "" means "first page not requested yet", None means "no more pages".
nextPageState = ""
# Accumulator for documents pulled from Astra that pass the column check.
raw_dataset = []
# Keys a document must contain to be usable for evaluation.
expected_columns = ['_id','instruction', 'input', 'output', 'original_response', 'fine_tuned_response']

def check_expected_columns(raw_instruction):
    """Return True if *raw_instruction* contains every key in expected_columns.

    Used to drop incomplete documents coming back from Astra before they
    enter the evaluation dataset.
    """
    # all(...) already yields the boolean we want; no need for an
    # explicit if/else returning True/False.
    return all(column in raw_instruction for column in expected_columns)
# Page through the entire collection, keeping only documents that contain
# every expected column.  nextPageState is "" before the first request,
# an opaque token between pages, and None once Astra reports no further
# pages — use `is not None` for the sentinel comparison (PEP 8), and
# share the extraction logic between the two fetch branches.
while nextPageState is not None:
    if nextPageState == "":
        # First page: a plain find with default options.
        data = collection.find()
    else:
        # Subsequent pages: resume from the server-provided page state.
        data = collection.find(options={"pageState": nextPageState}, sort=None)
    nextPageState = data['data']['nextPageState']
    raw_dataset.extend(
        instruction
        for instruction in data['data']['documents']
        if check_expected_columns(instruction)
    )
# Builds the combined prompt texts for one Astra document.
def create_prompts(record):
    """Attach 'original_text' and 'fine_tuned_text' prompts to *record*.

    Both prompts share the same header, instruction, and context sections
    and differ only in which model response they embed.  Mutates and
    returns *record*.
    """
    header = "Read the Instruction below and provide an answer."
    instruction_section = f"### INSTRUCTION:\n{record['instruction']}\n\n"
    context_section = f"### Context:\n{record['input']}\n"
    footer = "### End"

    def build(answer_text):
        # Drop any empty section, join the rest with blank lines, and
        # turn literal "\n" escape sequences back into real newlines.
        answer_section = f"### Response:\n {answer_text}\n\n"
        sections = [header, instruction_section, context_section, answer_section, footer]
        return "\n\n".join(s for s in sections if s).replace('\\n', '\n')

    record["original_text"] = build(record['original_response'])
    record["fine_tuned_text"] = build(record['fine_tuned_response'])
    return record
# Attach prompt texts to every fetched document.
combined_dataset = list(map(create_prompts, raw_dataset))
# dtype='string' casts all columns to the pandas string dtype.
dataframe = pd.DataFrame(data=combined_dataset, dtype='string')
# Prints a schema/memory summary — a manual sanity check of the load.
dataframe.info()
dataset = Dataset.from_pandas(dataframe)
# Evaluate only rows in the half-open index range [idx_min, idx_max).
idx_min = 0
idx_max = 500
partial_dataset = dataset.filter(lambda example, idx: idx >= idx_min and idx < idx_max, with_indices=True)
def generate_rating(instruction_and_response):
    """Ask the OpenAI chat model to rate one prompt/response pair.

    The evaluation rubric comes from evaluation_prompts; the pair under
    review is appended after it.  Returns the raw chat-completion
    response object — the caller extracts the rating from the message
    content.
    """
    user_message = evaluation_prompts.evaluation_prompt + "\n" + instruction_and_response
    chat_messages = [
        {"role": "system", "content": evaluation_prompts.system_prompt},
        {"role": "user", "content": user_message},
    ]
    return client.chat.completions.create(
        model="gpt-3.5-turbo-0125",
        messages=chat_messages,
    )
def _extract_rating(message_content):
    """Return the first integer appearing in *message_content*.

    The previous code stripped non-digits and took only the FIRST DIGIT,
    so a model reply of "10" was recorded as 1 — fatal for an
    out-of-ten score.  Matching the first full number fixes that
    ("10" -> 10, "7/10" -> 7).  Raises ValueError if the reply contains
    no digits at all.
    """
    match = re.search(r'\d+', message_content)
    if match is None:
        raise ValueError(f"no rating found in model reply: {message_content!r}")
    return int(match.group())

# Rate both model responses for each row and persist the scores.
for row in partial_dataset:
    original_rating = _extract_rating(generate_rating(row['original_text']).choices[0].message.content)
    fine_tuned_rating = _extract_rating(generate_rating(row['fine_tuned_text']).choices[0].message.content)
    print("Original: "+str(original_rating)+"\t Fine Tuned: "+str(fine_tuned_rating))
    # NOTE(review): this overwrites the stored response TEXT with the
    # numeric rating — confirm that clobbering original_response /
    # fine_tuned_response is intentional rather than writing the scores
    # to separate rating fields.
    collection.update_one(
        filter={"_id": row['_id']},
        update={"$set": {"original_response": original_rating,
                         "fine_tuned_response": fine_tuned_rating}},
    )