-
Hi, I am performing a segmentation task using Albumentations and PyTorch. During my first attempts I noticed that the mask tensor keeps the dtype uint8 after the transform, which ends up causing a runtime error later in PyTorch. In the end, I figured out that it was necessary to change the format of the mask with mask_img.astype(np.float32) inside the dataset.

Let me describe the situation. Say I created a custom dataset that uses OpenCV to fit nicely with Albumentations:

import cv2
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

class MyCustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.files = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        input_img_path = self.files.inputs[index]
        input_img_bgr = cv2.imread(input_img_path)
        input_img = cv2.cvtColor(input_img_bgr, cv2.COLOR_BGR2RGB)
        mask_img_path = self.files.masks[index]
        mask_img_bgr = cv2.imread(mask_img_path)
        mask_img = cv2.cvtColor(mask_img_bgr, cv2.COLOR_BGR2RGB)
        mask_img = mask_img.astype(np.float32)  # NEED TO CHANGE FORMAT HERE!
        if self.transform:
            augmentation = self.transform(image=input_img, mask=mask_img)
            input_img = augmentation["image"]
            mask_img = augmentation["mask"]
        return input_img, mask_img

Then, to apply some transformations, I use:

import albumentations as A
from albumentations.pytorch import ToTensorV2

transform_train = A.Compose([
    A.Normalize(
        mean=[0, 0, 0],
        std=[1, 1, 1],
        max_pixel_value=255.0,
    ),
    ToTensorV2(transpose_mask=True),
])

dataset = MyCustomDataset("dataset.csv", transform=transform_train)

Then, I create a data loader. However, if I DON'T include the mask_img.astype(np.float32) line in __getitem__, the mask that comes out of the loader is still uint8:

data_loader = DataLoader(dataset)
for x, y in data_loader:
    print(x.dtype, y.dtype)  # prints: torch.float32 torch.uint8

Isn't ToTensorV2 supposed to convert the mask to a float tensor as well? By inspecting the source code (see the two links below), I noticed that the functions that convert the image and the mask to tensors do not change the mask dtype. So I'm left with the question: is calling astype(np.float32) inside __getitem__, before the transform, the right way to handle this, or should the conversion happen somewhere else?
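To make the failure mode concrete, here is a minimal sketch of the kind of dtype error a uint8 mask can trigger downstream; the loss function is only an illustrative assumption, not necessarily the one used in the original project:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
pred = torch.randn(2, 3, 4, 4)                     # float32 model output
mask = torch.zeros(2, 3, 4, 4, dtype=torch.uint8)  # uint8 mask straight from the loader

# criterion(pred, mask) raises a RuntimeError because the target dtype
# does not match the float32 prediction; converting the mask first works:
loss = criterion(pred, mask.to(torch.float32))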
-
Looks like we need to add a control flag to change the mask dtype inside ToTensorV2. For now, it is better to call

mask = mask.to(torch.float32)

after calling the transform, because processing a uint8 mask is much faster.
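In other words, a minimal sketch of how __getitem__ from the question could look with this suggestion, assuming import torch in addition to the imports already shown above (whether the mask values also need scaling or class-index mapping depends on the task):

    def __getitem__(self, index):
        input_img_path = self.files.inputs[index]
        input_img = cv2.cvtColor(cv2.imread(input_img_path), cv2.COLOR_BGR2RGB)
        mask_img_path = self.files.masks[index]
        mask_img = cv2.cvtColor(cv2.imread(mask_img_path), cv2.COLOR_BGR2RGB)
        # keep the mask as uint8 here so the augmentations run on the cheaper dtype
        if self.transform:
            augmentation = self.transform(image=input_img, mask=mask_img)
            input_img = augmentation["image"]
            mask_img = augmentation["mask"].to(torch.float32)  # convert only at the end
        return input_img, mask_img

Converting after the transform keeps the Albumentations pipeline working on uint8, which is what makes it faster, while PyTorch still receives a float32 target.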