-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset_MIM.py
64 lines (53 loc) · 2.23 KB
/
dataset_MIM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import pickle
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomVerticalFlip
from utils.helpers import Fix_RandomRotation
class vessel_dataset(Dataset):
    """Dataset of pickled (image, ground truth, mask) triples from one directory.

    Samples are read from ``<path>/<mode>_pro``. Every file whose name starts
    with ``img`` is one sample; its companions are looked up under the same
    name with the leading ``img`` replaced by ``gt`` and ``mask``.

    Args:
        path: Directory that contains the ``<mode>_pro`` sub-directory.
        mode: Prefix of the data sub-directory name. Splitting and
            augmentation are only active when it equals ``"training"``.
        is_val: When true, keep the tail of the split instead of the head
            and disable augmentation.
        split: Optional fraction strictly between 0 and 1. In ``"training"``
            mode the first ``int(split * n)`` image files form the training
            part and the remaining files form the validation part.

    Raises:
        ValueError: If ``split`` is given in ``"training"`` mode but does not
            lie strictly between 0 and 1.
    """

    def __init__(self, path, mode, is_val=False, split=None):
        self.mode = mode
        self.is_val = is_val
        self.data_path = os.path.join(path, f"{mode}_pro")
        self.data_file = os.listdir(self.data_path)
        # Keep only the image files. They come back sorted, so the split
        # below selects the same files on every run and on every machine.
        self.img_file = self._select_img(self.data_file)
        if split is not None and mode == "training":
            # Explicit check instead of ``assert``: asserts vanish under -O.
            if not 0 < split < 1:
                raise ValueError(
                    f"split must be strictly between 0 and 1, got {split!r}"
                )
            cut = int(split * len(self.img_file))
            if is_val:
                self.img_file = self.img_file[cut:]
            else:
                self.img_file = self.img_file[:cut]
        self.transforms = Compose([
            RandomHorizontalFlip(p=0.5),
            RandomVerticalFlip(p=0.5),
            Fix_RandomRotation(),
        ])

    def __getitem__(self, idx):
        """Return the ``(img, gt, mask)`` float tensors of sample ``idx``.

        For the training part (``mode == "training"`` and not ``is_val``) the
        transform pipeline is applied to all three tensors.
        """
        img_file = self.img_file[idx]
        # Everything after the 3-character "img" prefix is shared by the
        # matching ground-truth and mask files.
        suffix = img_file[3:]
        img = self._load_tensor(img_file)
        gt = self._load_tensor("gt" + suffix)
        mask = self._load_tensor("mask" + suffix)

        if self.mode == "training" and not self.is_val:
            # Re-seed the global torch RNG with the same value before each
            # call so that img, gt and mask go through identical random
            # transforms.
            # NOTE(review): torch.seed() re-seeds the global RNG
            # non-deterministically, overriding any seed set by the caller.
            seed = torch.seed()
            torch.manual_seed(seed)
            img = self.transforms(img)
            torch.manual_seed(seed)
            gt = self.transforms(gt)
            torch.manual_seed(seed)
            mask = self.transforms(mask)

        return img, gt, mask

    def _load_tensor(self, file_name):
        """Unpickle ``file_name`` from the data directory as a float tensor."""
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this dataset at directories whose contents are trusted.
        with open(os.path.join(self.data_path, file_name), mode="rb") as file:
            return torch.from_numpy(pickle.load(file)).float()

    # Image files start with the three characters "img"; ground-truth files
    # start with the two characters "gt".
    def _select_img(self, file_list):
        """Return the names in ``file_list`` that start with ``img``, sorted."""
        return sorted(name for name in file_list if name[:3] == "img")

    def __len__(self):
        """Return the number of selected image files."""
        return len(self.img_file)