-
Notifications
You must be signed in to change notification settings - Fork 396
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os, json, itertools, bisect, gc | ||
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig | ||
import transformers | ||
import torch | ||
from accelerate import Accelerator | ||
import accelerate | ||
import time | ||
|
||
# Shared handles filled in by load_model(); they stay None until it has run.
model = tokenizer = generator = None

# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from csv_reader import csv_prompter | ||
|
||
def load_model(model_name, eight_bit=0, device_map="auto"):
    """Load a LLaMA checkpoint and publish it through the module globals.

    Sets the module-level ``tokenizer``, ``model`` and ``generator`` names;
    ``generator`` is the loaded model's bound ``generate`` method.

    Args:
        model_name: Path or identifier handed to ``from_pretrained``.
        eight_bit: Accepted for backward compatibility; not forwarded
            (the model is always loaded with ``load_in_8bit=False``).
        device_map: Accepted for backward compatibility; not forwarded
            (the whole model is moved to the GPU with ``.cuda()``).

    Raises:
        AttributeError: If the installed ``transformers`` exposes neither the
            released nor the pre-release LLaMA class names.
    """
    global model, tokenizer, generator

    print("Loading " + model_name + "...")

    gpu_count = torch.cuda.device_count()
    print('gpu_count', gpu_count)

    # Released transformers names these classes Llama*; the pre-release fork
    # this script was written against used LLaMA*. Accept either spelling.
    tokenizer_cls = (
        getattr(transformers, "LlamaTokenizer", None)
        or transformers.LLaMATokenizer
    )
    model_cls = (
        getattr(transformers, "LlamaForCausalLM", None)
        or transformers.LLaMAForCausalLM
    )

    tokenizer = tokenizer_cls.from_pretrained(model_name)
    model = model_cls.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache",
    ).cuda()

    generator = model.generate
|
||
|
||
# Load the ChatDoctor checkpoint once, before any patient input is read.
load_model("chatDoctor100k/")

# Greeting printed a single time at startup.
First_chat = "ChatDoctor: I am ChatDoctor, what medical questions do you have?"
print(First_chat)
|
||
|
||
def go():
    """Run one chat turn against the CSV-backed prompter.

    Reads a question from stdin, passes it to ``csv_prompter`` together with
    the module-level ``generator`` and ``tokenizer``, and prints the reply
    surrounded by blank lines.
    """
    question = input("Patient: ")
    print("")

    reply = csv_prompter(generator, tokenizer, question)

    print("")
    print("ChatDoctor: " + reply)
    print("")
|
||
# Keep taking chat turns until the patient ends the session. A closed stdin
# (EOFError) or Ctrl-C (KeyboardInterrupt) would otherwise escape the endless
# loop as an uncaught traceback, so treat both as a normal exit.
while True:
    try:
        go()
    except (EOFError, KeyboardInterrupt):
        print("")
        break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import torch | ||
from llama_index import WikipediaReader | ||
|
||
import pandas as pd | ||
|
||
|
||
|
||
def _generate_completion(generator, tokenizer, prompt):
    """Sample one completion for ``prompt`` and return the text that follows it."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

    with torch.no_grad():
        generated_ids = generator(
            input_ids,
            max_new_tokens=512,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 would switch the penalty off
            temperature=0.5,  # default: 1.0
            top_k=50,  # default: 50
            top_p=1.0,  # default: 1.0
            early_stopping=True,
        )

    # batch_decode yields one string per sequence; only one was requested.
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # NOTE(review): assumes the decoded text starts with the prompt verbatim.
    return generated_text[len(prompt):]


def csv_prompter(generator, tokenizer, question, csv_path='disease_symptom.csv'):
    """Answer ``question`` using a disease/symptom CSV as lookup context.

    Steps:
      1. Ask the model to extract comma-separated keywords from the question.
      2. Read ``csv_path`` and group its rows into chunks of six records.
      3. Score each chunk by how many keywords occur in it (case-insensitive
         substring match) and rank the chunks by score.
      4. Ask the model to answer the question against each of the four
         best-ranked chunks.
      5. If more than one answer was produced, ask the model to pick the best.

    Prompts, keywords and intermediate answers are printed as they are made.

    Args:
        generator: Callable with the calling convention of ``model.generate``.
        tokenizer: Tokenizer matching ``generator``; must be callable and
            provide ``batch_decode`` and ``eos_token_id``.
        question: The patient's question.
        csv_path: CSV file to read; defaults to the original hard-coded path.

    Returns:
        The last text generated: the synthesized answer when several chunk
        answers exist, the single chunk answer otherwise, or ``""`` when the
        CSV yields no chunks.
    """
    keyword_prompt = "A question is provided below. Given the question, extract " + \
                     "keywords from the text. Focus on extracting the keywords that we can use " + \
                     "to best lookup answers to the question. \n" + \
                     "---------------------\n" + \
                     "{}\n".format(question) + \
                     "---------------------\n" + \
                     "Provide keywords in the following comma-separated format.\nKeywords: "

    response = _generate_completion(generator, tokenizer, keyword_prompt)
    # str.strip() returns a new string; the original discarded the result, so
    # keywords kept stray whitespace and newlines.
    response = response.split("===")[0].strip()
    print(response)
    # Drop empty entries: "" is a substring of every chunk and would add the
    # same score to all of them.
    keywords = [kw.strip() for kw in response.split(",") if kw.strip()]
    print(keywords)

    records = pd.read_csv(csv_path).to_dict('records')
    divided_text = []
    for start in range(0, len(records), 6):
        csv_text = str(records[start:start + 6]).replace("}, {", "\n\n").replace("[", "").replace("]", "").replace("\"", "")
        divided_text.append(csv_text)

    lowered_keywords = [kw.lower() for kw in keywords]
    score_textlist = []
    for chunk in divided_text:
        lowered_chunk = chunk.lower()
        score_textlist.append(sum(kw in lowered_chunk for kw in lowered_keywords))

    # Highest score first; ties fall back to comparing the chunk text.
    ranked_chunks = [item for _, item in sorted(zip(score_textlist, divided_text), reverse=True)]

    answer_llama = ""
    answer_list = []

    for chunk in ranked_chunks[:4]:
        fulltext = "{}".format(chunk) + \
                   "\n---------------------\n" + \
                   "Based on the diseases and corresponding symptoms in the Table above, " + \
                   "answer the question: {}\n".format(question) + \
                   "Disease name and corresponding symptoms: "

        print(fulltext)
        answer_llama = _generate_completion(generator, tokenizer, fulltext)
        print()
        print("\nAnswer: " + answer_llama)
        print()
        answer_list.append(answer_llama)

    # With a single answer there is nothing to choose between; return it as is.
    if len(answer_list) > 1:
        fulltext = "The original question is as follows: {}\n".format(question) + \
                   "We have provided existing answers:\n" + \
                   "------------\n" + \
                   "{}\n".format(str("\n\n".join(answer_list))) + \
                   "------------\n" + \
                   "The best one answer: "

        print(fulltext)
        answer_llama = _generate_completion(generator, tokenizer, fulltext)
        print()
        print("\nAnswer: " + answer_llama)
        print()

    return answer_llama
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
disease,Symptom | ||
Endophthalmitis,"['Pain in eye', 'Eye redness', 'Swollen eye', 'Diminished vision', 'Itchiness of eye', 'Mass on eyelid', 'Paresthesia', 'Painful sinuses', 'Skin lesion', 'Foot or toe pain', 'Lacrimation', 'Painful urination']" | ||
Anxiety,"['Anxiety and nervousness', 'Depression', 'Sharp chest pain', 'Depressive or psychotic symptoms', 'Shortness of breath', 'Headache', 'Insomnia', 'Palpitations', 'Abnormal involuntary movements', 'Irregular heartbeat', 'Fears and phobias', 'Increased heart rate']" | ||
Anal fissure,"['Rectal bleeding', 'Pain of the anus', 'Blood in stool', 'Constipation', 'Sharp abdominal pain', 'Changes in stool appearance', 'Lower body pain', 'Irregular belly button', 'Cramps and spasms', 'Vaginal dryness', 'Symptoms of prostate', 'Itchiness of eye']" | ||
Hiatal hernia,"['Sharp abdominal pain', 'Sharp chest pain', 'Nausea', 'Vomiting', 'Difficulty in swallowing', 'Burning abdominal pain', 'Dizziness', 'Heartburn', 'Back pain', 'Upper abdominal pain', 'Vomiting blood', 'Regurgitation']" | ||
Poisoning due to antidepressants,"['Depression', 'Disturbance of memory', 'Drug abuse', 'Headache', 'Sleepiness', 'Abusing alcohol', 'Allergic reaction', 'Weakness', 'Fainting', 'Muscle pain', 'Emotional symptoms', 'Back weakness']" | ||
Open wound of the jaw,"['Facial pain', 'Lip swelling', 'Symptoms of the face', 'Wrist swelling', 'Mouth pain']" | ||
Pulmonary eosinophilia,"['Cough', 'Shortness of breath', 'Fever', 'Nasal congestion', 'Sharp chest pain', 'Difficulty breathing', 'Dizziness', 'Ache all over', 'Nosebleed', 'Redness in ear', 'Fluid retention', 'Paresthesia']" | ||
Poisoning due to anticonvulsants,"['Problems with movement', 'Seizures', 'Dizziness', 'Vomiting', 'Feeling ill', 'Depression', 'Headache', 'Difficulty speaking', 'Ringing in ear', 'Itching of skin', 'Pain in eye', 'Slurring words']" | ||
Pilonidal cyst,"['Skin growth', 'Lower body pain', 'Ache all over', 'Skin swelling', 'Pain of the anus', 'Bones are painful', 'Back mass or lump', 'Mass or swelling around the anus', 'Pelvic pain', 'Fluid retention', 'Low back swelling', 'Irregular appearing scalp']" | ||
Dislocation of the hip,"['Hip pain', 'Knee pain', 'Leg pain', 'Back pain', 'Groin pain', 'Hip stiffness or tightness', 'Abusing alcohol']" | ||
Lung cancer,"['Shortness of breath', 'Cough', 'Fatigue', 'Decreased appetite', 'Hemoptysis', 'Drainage in throat', 'Leg weakness', 'Smoking problems']" | ||
Tooth disorder,"['Toothache', 'Facial pain', 'Gum pain', 'Ear pain', 'Headache', 'Mouth pain', 'Peripheral edema', 'Jaw swelling', 'Pain in gums', 'Bleeding gums', 'Mouth ulcer', 'Swollen lymph nodes']" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os, json, itertools, bisect, gc | ||
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig | ||
import transformers | ||
import torch | ||
from accelerate import Accelerator | ||
import accelerate | ||
import time | ||
|
||
# Shared handles filled in by load_model(); they stay None until it has run.
model = tokenizer = generator = None

# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from wiki_reader import wiki_prompter | ||
|
||
def load_model(model_name, eight_bit=0, device_map="auto"):
    """Load a LLaMA checkpoint and publish it through the module globals.

    Sets the module-level ``tokenizer``, ``model`` and ``generator`` names;
    ``generator`` is the loaded model's bound ``generate`` method.

    Args:
        model_name: Path or identifier handed to ``from_pretrained``.
        eight_bit: Accepted for backward compatibility; not forwarded
            (the model is always loaded with ``load_in_8bit=False``).
        device_map: Accepted for backward compatibility; not forwarded
            (the whole model is moved to the GPU with ``.cuda()``).

    Raises:
        AttributeError: If the installed ``transformers`` exposes neither the
            released nor the pre-release LLaMA class names.
    """
    global model, tokenizer, generator

    print("Loading " + model_name + "...")

    gpu_count = torch.cuda.device_count()
    print('gpu_count', gpu_count)

    # Released transformers names these classes Llama*; the pre-release fork
    # this script was written against used LLaMA*. Accept either spelling.
    tokenizer_cls = (
        getattr(transformers, "LlamaTokenizer", None)
        or transformers.LLaMATokenizer
    )
    model_cls = (
        getattr(transformers, "LlamaForCausalLM", None)
        or transformers.LLaMAForCausalLM
    )

    tokenizer = tokenizer_cls.from_pretrained(model_name)
    model = model_cls.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache",
    ).cuda()

    generator = model.generate
|
||
|
||
# Load the ChatDoctor checkpoint once, before any patient input is read.
load_model("chatDoctor100k/")

# Greeting printed a single time at startup.
First_chat = "ChatDoctor: I am ChatDoctor, what medical questions do you have?"
print(First_chat)
|
||
def go():
    """Run one chat turn against the Wikipedia-backed prompter.

    Reads a question from stdin, passes it to ``wiki_prompter`` together with
    the module-level ``generator`` and ``tokenizer``, and prints the reply
    surrounded by blank lines.
    """
    question = input("Patient: ")
    print("")

    reply = wiki_prompter(generator, tokenizer, question)

    print()
    print("ChatDoctor: " + reply)
    print()
|
||
# Keep taking chat turns until the patient ends the session. A closed stdin
# (EOFError) or Ctrl-C (KeyboardInterrupt) would otherwise escape the endless
# loop as an uncaught traceback, so treat both as a normal exit.
while True:
    try:
        go()
    except (EOFError, KeyboardInterrupt):
        print("")
        break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import torch | ||
from llama_index import WikipediaReader | ||
|
||
|
||
def divide_string(wiki_page, word_limit=50):
    """Split the text of loaded documents into chunks of at most ``word_limit`` words.

    Args:
        wiki_page: Iterable of document lists, one list per looked-up page.
            Every document must expose its content as a ``text`` string.
            Empty lists are skipped, and every document in a list is used
            (previously only the first was read and an empty list raised
            ``IndexError``).
        word_limit: Maximum number of whitespace-separated words per chunk.

    Returns:
        A list of chunk strings in input order. Words inside a chunk are
        joined by single spaces; the last chunk of a document may be shorter.

    Raises:
        ValueError: If ``word_limit`` is not positive.
    """
    if word_limit <= 0:
        raise ValueError("word_limit must be a positive integer")

    divided_text = []
    for each_page in wiki_page:
        for document in each_page:
            words = document.text.split()
            divided_text.extend(
                ' '.join(words[start:start + word_limit])
                for start in range(0, len(words), word_limit)
            )

    return divided_text
|
||
|
||
|
||
def _generate_completion(generator, tokenizer, prompt):
    """Sample one completion for ``prompt`` and return the text that follows it."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

    with torch.no_grad():
        generated_ids = generator(
            input_ids,
            max_new_tokens=512,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 would switch the penalty off
            temperature=0.5,  # default: 1.0
            top_k=50,  # default: 50
            top_p=1.0,  # default: 1.0
            early_stopping=True,
        )

    # batch_decode yields one string per sequence; only one was requested.
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # NOTE(review): assumes the decoded text starts with the prompt verbatim.
    return generated_text[len(prompt):]


def wiki_prompter(generator, tokenizer, question):
    """Answer ``question`` using Wikipedia pages as lookup context.

    Steps:
      1. Ask the model to extract comma-separated keywords from the question.
      2. Load one Wikipedia page per keyword; keywords whose lookup fails are
         reported and skipped.
      3. Split the pages into 250-word chunks, score each chunk by how many
         keywords occur in it (case-insensitive substring match) and rank
         the chunks by score.
      4. Ask the model to answer the question against each of the four
         best-ranked chunks.
      5. If more than one answer was produced, ask the model to pick the best.

    Prompts, keywords and intermediate answers are printed as they are made.

    Args:
        generator: Callable with the calling convention of ``model.generate``.
        tokenizer: Tokenizer matching ``generator``; must be callable and
            provide ``batch_decode`` and ``eos_token_id``.
        question: The patient's question.

    Returns:
        The last text generated: the synthesized answer when several chunk
        answers exist, the single chunk answer otherwise, or ``""`` when no
        Wikipedia text could be loaded.
    """
    keyword_prompt = "A question is provided below. Given the question, extract " +\
                     "keywords from the text. Focus on extracting the keywords that we can use " +\
                     "to best lookup answers to the question. \n" +\
                     "---------------------\n" +\
                     "{}\n".format(question) +\
                     "---------------------\n" +\
                     "Provide keywords in the following comma-separated format.\nKeywords: "

    response = _generate_completion(generator, tokenizer, keyword_prompt)
    # str.strip() returns a new string; the original discarded the result, so
    # keywords kept stray whitespace and newlines and page lookups missed.
    response = response.split("===")[0].strip()
    print(response)
    # Drop empty entries: "" is a substring of every chunk and is not a
    # meaningful page title.
    keywords = [kw.strip() for kw in response.split(",") if kw.strip()]
    print(keywords)

    wiki_docs = []
    for keyw in keywords:
        try:
            wiki_one = WikipediaReader().load_data(pages=[keyw], auto_suggest=False)
        except Exception:
            # Best-effort lookup: a missing page must not abort the answer,
            # but Ctrl-C / SystemExit should still propagate.
            print("No wiki: " + keyw)
        else:
            # divide_string indexes the first document, so skip empty results.
            if wiki_one:
                wiki_docs.append(wiki_one)

    divided_text = divide_string(wiki_docs, 250)

    lowered_keywords = [kw.lower() for kw in keywords]
    score_textlist = []
    for chunk in divided_text:
        lowered_chunk = chunk.lower()
        score_textlist.append(sum(kw in lowered_chunk for kw in lowered_keywords))

    # Highest score first; ties fall back to comparing the chunk text.
    ranked_chunks = [item for _, item in sorted(zip(score_textlist, divided_text), reverse=True)]

    answer_llama = ""
    answer_list = []

    for chunk in ranked_chunks[:4]:
        fulltext = "Context information is below. \n" +\
                   "---------------------\n" +\
                   "{}".format(chunk) +\
                   "\n---------------------\n" +\
                   "Given the context information and not prior knowledge, " +\
                   "answer the question: {}\n".format(question) +\
                   "Response: "

        print(fulltext)
        answer_llama = _generate_completion(generator, tokenizer, fulltext)
        print()
        print("\nAnswer: " + answer_llama)
        print()
        answer_list.append(answer_llama)

    # With a single answer there is nothing to choose between; return it as is.
    if len(answer_list) > 1:
        fulltext = "The original question is as follows: {}\n".format(question) +\
                   "We have provided existing answers:\n" +\
                   "------------\n" +\
                   "{}\n".format(str("\n\n".join(answer_list))) +\
                   "------------\n" +\
                   "The best one answer: "

        print(fulltext)
        answer_llama = _generate_completion(generator, tokenizer, fulltext)
        print()
        print("\nAnswer: " + answer_llama)
        print()

    return answer_llama
|
||
|