diff --git a/midv500/convert_dataset.py b/midv500/convert_dataset.py
index eebcf07..b049567 100644
--- a/midv500/convert_dataset.py
+++ b/midv500/convert_dataset.py
@@ -6,6 +6,7 @@ from tqdm import tqdm
 
 from midv500.utils import list_annotation_paths_recursively, get_bbox_inside_image, create_dir
 
+
 def convert(root_dir: str, export_dir: str):
     """
     Walks inside root_dir (should oly contain original midv500 dataset folders),
@@ -31,8 +32,8 @@ def convert(root_dir: str, export_dir: str):
     print("Converting to coco.")
     for ind, rel_annotation_path in enumerate(tqdm(annotation_paths)):
         # get image path
-        rel_image_path = rel_annotation_path.replace("ground_truth","images")
-        rel_image_path = rel_image_path.replace("json","tif")
+        rel_image_path = rel_annotation_path.replace("ground_truth", "images")
+        rel_image_path = rel_image_path.replace("json", "tif")
 
         # load image
         abs_image_path = os.path.join(root_dir, rel_image_path)
@@ -63,9 +64,9 @@ def convert(root_dir: str, export_dir: str):
         # create mask from poly coords
         mask = np.zeros(image.shape, dtype=np.uint8)
         mask_coords_np = np.array(mask_coords, dtype=np.int32)
-        cv2.fillPoly(mask, mask_coords_np.reshape(-1, 4, 2), color=(255,255,255))
+        cv2.fillPoly(mask, mask_coords_np.reshape(-1, 4, 2), color=(255, 255, 255))
         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
-        mask = cv2.threshold(mask, 0,255, cv2.THRESH_BINARY)[1]
+        mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)[1]
 
         # get voc style bounding box coordinates [minx, miny, maxx, maxy] of the mask
         label_xmin = min([pos[0] for pos in mask_coords])
@@ -83,7 +84,7 @@ def convert(root_dir: str, export_dir: str):
         annotation_dict = dict()
         annotation_dict["iscrowd"] = 0
         annotation_dict["image_id"] = image_dict['id']
-        annotation_dict['category_id'] = 1 # id card
+        annotation_dict['category_id'] = 1  # id card
         annotation_dict['ignore'] = 0
         annotation_dict['id'] = ind
 
diff --git a/midv500/download_dataset.py b/midv500/download_dataset.py
index cad1eff..1dab8e5 100644
--- a/midv500/download_dataset.py
+++ b/midv500/download_dataset.py
@@ -53,6 +53,7 @@
                  'ftp://smartengines.com/midv-500/dataset/49_usa_ssn82.zip',
                  'ftp://smartengines.com/midv-500/dataset/50_xpo_id.zip']
 
+
 def download_dataset(download_dir: str):
     """
     This script downloads the MIDV-500 dataset and unzips the folders.
@@ -66,18 +67,19 @@ def download_dataset(download_dir: str):
         print('Downloaded:', link[40:])
 
         # unzip zip file
         print('Unzipping:', link[40:])
-        zip_path = os.path.join(download_dir,link[40:])
+        zip_path = os.path.join(download_dir, link[40:])
         unzip(zip_path, download_dir)
         print('Unzipped:', link[40:].replace('.zip', ''))
 
         # remove zip file
         os.remove(zip_path)
 
+
 if __name__ == '__main__':
     # construct the argument parser
     ap = argparse.ArgumentParser()
     # add the arguments to the parser
-    ap.add_argument("download_dir", default = "data/", help="Directory for MIDV-500 dataset to be downloaded.")
+    ap.add_argument("download_dir", default="data/", help="Directory for MIDV-500 dataset to be downloaded.")
     args = vars(ap.parse_args())
     # download dataset
diff --git a/midv500/utils.py b/midv500/utils.py
index 74dbffe..f374f6f 100644
--- a/midv500/utils.py
+++ b/midv500/utils.py
@@ -87,7 +87,7 @@ def list_annotation_paths_recursively(directory: str, ignore_background_only_one
             continue
         relative_filepath = abs_filepath.split(directory)[-1]
-        relative_filepath = relative_filepath.replace("\\", "/") # for windows
+        relative_filepath = relative_filepath.replace("\\", "/")  # for windows
         relative_filepath_list.append(relative_filepath)
 
     number_of_files = len(relative_filepath_list)
 
diff --git a/setup.cfg b/setup.cfg
index b8460f2..825916b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,7 @@
 [flake8]
 max-line-length = 119
 exclude =.git,__pycache__,docs/source/conf.py,build,dist
-ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006
+ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,E722,W503,B006
 inline-quotes = "
 
 [mypy]