-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset_processing.py
60 lines (46 loc) · 2.32 KB
/
dataset_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import re
import yaml
import argparse
import logging
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import Text
# Locates the image reference at the end of a question, e.g. "... in the image1 ?".
# Group 1 is the whole matched phrase (removed from the question text) and
# group 4 is the image identifier such as "image1". Raw string so that \d and \?
# reach the regex engine without relying on invalid-escape passthrough.
_IMAGE_EXPRESSION = re.compile(r"((in |on |of )?(the |this )?(image\d*) \?)")


def _parse_qa_lines(qa_lines, question_col, answer_col, image_col):
    """Turn alternating question/answer lines into a list of row dicts.

    Even-indexed lines are questions and the line after each is its answer.
    The image identifier is extracted from the question, and every occurrence
    of the matched image phrase is removed from the question text.

    Raises:
        ValueError: if the lines cannot be paired, or a question has no
            image reference.
    """
    lines = list(qa_lines)
    # Tolerate blank lines at the end of the file instead of mis-pairing them.
    while lines and not lines[-1].strip():
        lines.pop()
    if len(lines) % 2:
        raise ValueError(
            "Expected alternating question/answer lines, "
            f"got an odd number of lines ({len(lines)})"
        )

    records = []
    for pair_idx, (question_line, answer_line) in enumerate(zip(lines[0::2], lines[1::2])):
        match = _IMAGE_EXPRESSION.search(question_line)
        if match is None:
            raise ValueError(
                f"No image reference found in question on line {2 * pair_idx + 1}: "
                f"{question_line!r}"
            )
        records.append(
            {
                question_col: question_line.replace(match.group(1), ""),
                answer_col: answer_line,
                image_col: match.group(4),
            }
        )
    return records


def _build_answer_space(answers):
    """Return the sorted, de-duplicated vocabulary of answers.

    An answer containing a comma is treated as a list: all spaces are removed
    and it is split on commas ("chair, table" -> "chair", "table"). Answers
    without a comma are kept verbatim.
    """
    vocabulary = set()
    for ans in answers:
        if "," in ans:
            vocabulary.update(ans.replace(" ", "").split(","))
        else:
            vocabulary.add(ans)
    return sorted(vocabulary)


def DatasetProcessor(config: Text, split_ratio: Text) -> None:
    """Build the answer space and train/eval CSV splits from a QA text file.

    Args:
        config: Path to a YAML file whose ``data`` section provides
            ``dataset_folder``, ``all_qa_pairs_file``, ``question_col``,
            ``answer_col``, ``image_col``, ``answer_space``, ``train_dataset``
            and ``eval_dataset``.
        split_ratio: Fraction of rows held out for the eval set, strictly
            between 0 and 1. It may be a string (as passed from the CLI) or
            a number. ``None`` falls back to the previously hard-coded 0.2.

    Side effects:
        Writes ``data.csv`` (all rows) to the current working directory.
        Writes the answer-space, train and eval files into ``dataset_folder``.
        Configures root logging at INFO level.

    Raises:
        ValueError: if ``split_ratio`` is not a fraction in (0, 1), or the QA
            file is malformed.
    """
    if split_ratio is None:
        test_size = 0.2
    else:
        try:
            test_size = float(split_ratio)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"split_ratio must be a number in (0, 1), got {split_ratio!r}"
            ) from err
    # The comparison is False for NaN as well, so NaN is rejected too.
    if not 0.0 < test_size < 1.0:
        raise ValueError(f"split_ratio must be a fraction in (0, 1), got {split_ratio!r}")

    with open(config, encoding="utf-8") as cf:
        data_cfg = yaml.safe_load(cf)["data"]
    logging.basicConfig(level=logging.INFO)
    logging.info("Eval split ratio: %s", test_size)

    folder = data_cfg["dataset_folder"]
    question_col = data_cfg["question_col"]
    answer_col = data_cfg["answer_col"]
    image_col = data_cfg["image_col"]

    qa_path = os.path.join(folder, data_cfg["all_qa_pairs_file"])
    with open(qa_path, encoding="utf-8") as f:
        qa_lines = [line.rstrip("\n") for line in f]
    logging.info("Loaded all question-answer pairs")

    # Build the frame once from collected rows: DataFrame.append was removed in
    # pandas 2.0, and growing a frame row by row is quadratic.
    records = _parse_qa_lines(qa_lines, question_col, answer_col, image_col)
    df = pd.DataFrame(records, columns=[question_col, answer_col, image_col])
    df.to_csv("data.csv", index=None)

    logging.info("Creating space of all possible answers")
    # Read the configured answer column rather than a hard-coded "answer".
    answer_space = _build_answer_space(df[answer_col].to_list())
    with open(os.path.join(folder, data_cfg["answer_space"]), "w", encoding="utf-8") as f:
        f.write("\n".join(answer_space))

    logging.info("Splitting into train & eval sets")
    train_df, test_df = train_test_split(df, test_size=test_size, random_state=42)
    train_df.to_csv(os.path.join(folder, data_cfg["train_dataset"]), index=None)
    test_df.to_csv(os.path.join(folder, data_cfg["eval_dataset"]), index=None)
if __name__ == "__main__":
    # Command-line entry point. Both options are mandatory and are forwarded
    # to DatasetProcessor as the raw strings argparse collected.
    cli = argparse.ArgumentParser()
    for flag in ("config", "split_ratio"):
        cli.add_argument(f"--{flag}", dest=flag, required=True)
    parsed = cli.parse_args()
    DatasetProcessor(parsed.config, parsed.split_ratio)