34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
# import the necessary packages
|
|
from torch.utils.data import Dataset
|
|
import cv2
|
|
|
|
class SegmentationDataset(Dataset):
|
|
def __init__(self, imagePaths, maskPaths, transforms):
|
|
# store the image and mask filepaths, and augmentation
|
|
# transforms
|
|
self.imagePaths = imagePaths
|
|
self.maskPaths = maskPaths
|
|
self.transforms = transforms
|
|
|
|
def __len__(self):
|
|
# return the number of total samples contained in the dataset
|
|
return len(self.imagePaths)
|
|
|
|
def __getitem__(self, idx):
|
|
# grab the image path from the current index
|
|
imagePath = self.imagePaths[idx]
|
|
|
|
# load the image from disk, swap its channels from BGR to RGB,
|
|
# and read the associated mask from disk in grayscale mode
|
|
image = cv2.imread(imagePath)
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
mask = cv2.imread(self.maskPaths[idx], 0)
|
|
|
|
# check to see if we are applying any transformations
|
|
if self.transforms is not None:
|
|
# apply the transformations to both image and its mask
|
|
image = self.transforms(image)
|
|
mask = self.transforms(mask)
|
|
|
|
# return a tuple of the image and its mask
|
|
return (image, mask) |