That's the question I asked myself when I participated in an LLM competition that was ending in a week.
- I found out about the competition only a week before it closed. I only participated because I wanted to get to know RAG better. I knew I had no chance of winning xd.
- I decided to build a noob RAG system from scratch to tackle the competition.
https://challenge.kalapa.vn/portal/vietnamese-medical-question-answering/overview
The competition provided a dataset with information (causes, symptoms, prevention methods, etc.) about ~600 diseases in Vietnamese.
- The dataset contains about 600 files on various diseases, taken from a medical website.
Participants use the dataset to answer multiple-choice questions, and a question can have more than one correct answer.
| id | question | option_1 | option_2 | option_3 | option_4 | option_5 | option_6 |
|---|---|---|---|---|---|---|---|
| 1 | ...Hương có thể kiểm tra phát hiện bệnh này từ tuần thứ mấy của thai kỳ? | A. Tuần 10 | B. Tuần 20 | C. Tuần 30 | D. Tuần 40 | | |

*(Translation: ...From which week of pregnancy can Hương get tested to detect this disease? A. Week 10 B. Week 20 C. Week 30 D. Week 40)*
- Required output: a binary string of length `n`, where the `i`-th element is `0` if the `i`-th choice in the question is incorrect and `1` if it is correct.
| Question | Expected Answer |
|---|---|
| Đâu là triệu chứng của bệnh van tim? A. Khó thở B. Tăng cân nhanh chóng C. Vàng da D. Rụng tóc | 1100 |

*(Translation: Which are symptoms of heart valve disease? A. Shortness of breath B. Rapid weight gain C. Jaundice D. Hair loss)*
- Recover the Vietnamese accents in the data file names: use the disease name at the beginning of each file, or rename the files manually.
For each question, answer 2 sub-questions:
- Which files to retrieve: details below.
- Which chunks of text in the retrieved files to use as context: use `sentence-transformers` to embed the text and `semantic_search` to get the `top-k` related chunks (see the sketch after this list).

Output: context for the LLM.
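
Here's a minimal sketch of that chunk-retrieval step. The embedding model name is my assumption (the post doesn't say which model was used); `util.semantic_search` is the standard `sentence-transformers` helper.

```python
# Sketch of the chunk-retrieval step. The model name is an assumption;
# any multilingual sentence-embedding model would play the same role.
from sentence_transformers import SentenceTransformer, util

embedder = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")

def get_top_chunks(question, chunks, top_k=3):
    # Embed the question and every candidate chunk of the retrieved file.
    question_emb = embedder.encode(question, convert_to_tensor=True)
    chunk_embs = embedder.encode(chunks, convert_to_tensor=True)
    # semantic_search returns, per query, a ranked list of {corpus_id, score}.
    hits = util.semantic_search(question_emb, chunk_embs, top_k=top_k)[0]
    return [chunks[hit["corpus_id"]] for hit in hits]
```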
- For each question, for each proposed answer, ask the LLM whether the answer is correct given the context from the Retrieval step. If yes, output `1`, else `0`.
- Output format: to make sure the output matches the required format, I use the method described below.

Output: binary string.
- At that moment, I didn't know how to use existing libraries like `langchain` and `llamaindex` (I'm now kinda a "pro" at `llamaindex`, believe me), and there was less than a week left before the competition ended (too short for me to learn anything new).
- I wanted to understand each part of a RAG system deeply (how many parts a simple RAG pipeline has, how to do indexing, get embeddings, retrieve context, etc.).
- Because there are ~600 files/diseases, if there are 100 questions to answer then I'd have to run `semantic_search` something like 60,000 times. God, I only had Colab's free T4s (my Mac chip sucks at working with LLMs, `mps` is not fully supported yet).
- Instead, I used the `question` itself: for each question, I compared its similarity with the ~600 file/disease names (not the contents of the files!).
  - More specifically, I used `ROUGE-L` to measure their similarity. It's like matching words in the `question` with words in the `disease_name`.
  - Even more improvement: I also used the `choices` in the comparison, in case the question doesn't contain enough information.
- Then, I take the `top-2` files as context for each question. And for each file, I search for the `top-3` most related chunks of text. (A rough sketch of this file-selection trick follows.)
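
Here's a hedged sketch of that file-selection trick, assuming whitespace tokenization (`rouge_l_f1` and `top_files` are illustrative names, not the original code):

```python
def rouge_l_f1(candidate, reference):
    # ROUGE-L: F1 over the longest common subsequence (LCS) of words.
    a, b = candidate.lower().split(), reference.lower().split()
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, wa in enumerate(a):
        for j, wb in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if wa == wb else max(dp[i][j + 1], dp[i + 1][j])
    lcs = dp[-1][-1]
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(a), lcs / len(b)
    return 2 * precision * recall / (precision + recall)

def top_files(question, choices, disease_names, k=2):
    # Concatenating the choices adds signal when the question itself is vague.
    query = question + " " + " ".join(choices)
    ranked = sorted(disease_names, key=lambda name: rouge_l_f1(name, query), reverse=True)
    return ranked[:k]
```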
- For each possible choice, "ask" the model whether it is a correct answer (output `1`) or not (output `0`).

```python
instruct_prompt = (
    "Bối cảnh: {context}\n\n"                           # "Context: {context}"
    # Just a random few-shot example here; it improves the stability of the model output.
    "Câu hỏi: Cách nào để phòng ngừa bệnh tim mạch?\n"  # "Question: How to prevent cardiovascular disease?"
    "Câu trả lời đề xuất: Tập thể dục đều đặn.\n"       # "Proposed answer: Exercise regularly."
    "Điểm:1\n\n"                                        # "Score:1"
    "Câu hỏi: {question}\n"
    "Câu trả lời đề xuất: {choice}\n"
    "Điểm:"                                             # the model should continue with "1" or "0"
)
```
- Since the model only "talks" in probabilities and numbers, I had to find a way to know whether it was saying the answer is correct or not. I did that by taking the `logits` output by the model and putting them through a `softmax` function to get the probabilities.
```python
import numpy as np
import torch

def is_the_answer_correct(model, tokenizer, prompt, correct_str="1", incorrect_str="0"):
    """
    Given the context, question, and proposed answer in the prompt,
    predict whether the answer is correct based on the probability
    the model assigns to the correct vs. incorrect string.
    """
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        # Get the logits for the next token at the end of the prompt.
        logits = model(input_ids=input_ids).logits[0, -1]
    correct_str_token = tokenizer(correct_str).input_ids[-1]
    incorrect_str_token = tokenizer(incorrect_str).input_ids[-1]
    # Restrict the logits to the two tokens and put them through a softmax.
    probs = torch.nn.functional.softmax(
        torch.stack([
            logits[incorrect_str_token],
            logits[correct_str_token],
        ]).float(),
        dim=0,
    ).detach().cpu().numpy()
    return {0: incorrect_str, 1: correct_str}[np.argmax(probs)], probs
```
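
For completeness, here's a hypothetical wrapper (not from the original notebook) showing how the per-choice predictions assemble into the required binary string:

```python
def answer_question(model, tokenizer, context, question, choices):
    # Query the model once per choice and concatenate the 0/1 predictions.
    bits = []
    for choice in choices:
        prompt = instruct_prompt.format(context=context, question=question, choice=choice)
        predicted, _probs = is_the_answer_correct(model, tokenizer, prompt)
        bits.append(predicted)
    return "".join(bits)  # e.g. "1100" for a 4-choice question
```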
- Final score: 0.6468 (rank 21/90 🫣).
- Breakdown:
  - 0.2744: used a Vietnamese GPT model, a 1-shot prompt, and both English and Vietnamese in a lengthy, redundant prompt.
  - 0.47: switched to `bloomz-7b1-mt`, a Vietnamese prompt, still 1-shot.
  - 0.5915: fixed cases where the model only output zeros (by taking the answer with the highest probability of outputting the correct string), as sketched below.
  - 0.6118: fixed some coding bugs.
  - 0.6468: 2-shot prompt, tweaked prompt punctuation and such.
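
A sketch of that all-zeros fix, assuming each choice's probabilities from `is_the_answer_correct` are kept around (`fix_all_zeros` is an illustrative name):

```python
def fix_all_zeros(bits, probs_per_choice):
    # probs_per_choice[i] is the (P(incorrect), P(correct)) pair for choice i.
    if "1" not in bits:
        # The model rejected every choice; flip on the most likely one.
        best = max(range(len(bits)), key=lambda i: probs_per_choice[i][1])
        bits = bits[:best] + "1" + bits[best + 1:]
    return bits
```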
- Tried but failed (why am I writing like I'm the winner??):
  - increasing `k` in the top-k files
  - using more few-shot examples (>2)
  - fine-tuning the 4-bit bloomz model as a classifier to output 1 for a correct answer, 0 otherwise (the baseline score increased, but I didn't know how to prompt it correctly).
- Skills learned:
  - use `sentence-transformers` to do semantic search and retrieval.
  - implement a basic `sentence-splitter` (same as `llamaindex`'s), sketched below.
  - how to answer multiple-choice questions with an LLM.
  - be creative and apply small tricks to largely improve the result.
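
As an illustration of that splitter, here's a rough word-count sketch (the real `llamaindex` `SentenceSplitter` also handles chunk overlap and token counting):

```python
import re

def split_into_chunks(text, chunk_size=128):
    # Split on sentence-ending punctuation, then greedily pack whole
    # sentences into chunks of at most `chunk_size` words.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current, length = [], [], 0
    for sent in sentences:
        n = len(sent.split())
        if current and length + n > chunk_size:
            chunks.append(" ".join(current))
            current, length = [], 0
        current.append(sent)
        length += n
    if current:
        chunks.append(" ".join(current))
    return chunks
```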
- A running notebook can be found here: https://www.kaggle.com/code/quangphm/multiple-choice-rag/notebook