-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
26 lines (17 loc) · 884 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# coding: utf-8
from typing import Dict, Tuple, List
from torchvision.datasets import ImageFolder
class GlyphData(ImageFolder):
    """ImageFolder that uses a caller-supplied label-to-ID mapping.

    The mapping returned by :meth:`find_classes` is the one passed to the
    constructor, not one derived from the directory passed to it.
    """

    def __init__(self, class_to_idx: Dict[str, int], root: str = "prepared_data/train/", *args, **kwargs):
        """
        Reading GlyphDataset as an ImageFolder with the custom class-to-id mapping.

        :param class_to_idx: custom mapping of labels (str) to IDs (int); must not be empty
        :param root: train or test data directory
        :param args: extra positional arguments forwarded to ImageFolder after ``root``
        :param kwargs: extra keyword arguments forwarded to ImageFolder
        :raises ValueError: if ``class_to_idx`` is empty
        """
        if not class_to_idx:
            # max() on an empty mapping would raise an opaque ValueError; keep
            # the exception type but say what is actually wrong.
            raise ValueError("class_to_idx must contain at least one label")

        # ID -> label table sized to the largest ID; IDs that have no label in
        # the mapping keep the "UNKNOWN" placeholder.
        self.classes_list = ["UNKNOWN"] * (max(class_to_idx.values()) + 1)
        self.classes_map = class_to_idx
        for label, idx in class_to_idx.items():
            self.classes_list[idx] = label

        # The attributes above are set before the parent constructor runs
        # because find_classes() reads them (the parent is expected to call it
        # during initialisation -- see torchvision docs).
        # `root` is passed positionally: `root=root, *args` would bind the first
        # extra positional argument to `root` as well and raise TypeError.
        super().__init__(root, *args, **kwargs)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Return the custom class list and mapping built in the constructor.

        :param directory: ignored; the classes are not derived from the directory contents
        :return: tuple of (ID-indexed list of labels, label-to-ID mapping)
        """
        return self.classes_list, self.classes_map