diff --git a/pyimagesearch/config.py b/pyimagesearch/config.py
new file mode 100644
index 0000000..63b49e5
--- /dev/null
+++ b/pyimagesearch/config.py
@@ -0,0 +1,47 @@
+# import the necessary packages
+import torch
+import os
+
+# base path of the dataset
+DATASET_PATH = os.path.join("dataset", "train")
+
+# define the path to the images and masks dataset
+IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "images")
+MASK_DATASET_PATH = os.path.join(DATASET_PATH, "masks")
+
+# define the test split
+TEST_SPLIT = 0.15
+
+# determine the device to be used for training and evaluation
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# determine if we will be pinning memory during data loading
+PIN_MEMORY = True if DEVICE == "cuda" else False
+
+# define the number of channels in the input, number of classes,
+# and number of levels in the U-Net model
+NUM_CHANNELS = 1
+NUM_CLASSES = 1
+NUM_LEVELS = 3
+
+# initialize learning rate, number of epochs to train for, and the
+# batch size
+INIT_LR = 0.001
+NUM_EPOCHS = 40
+BATCH_SIZE = 64
+
+# define the input image dimensions
+INPUT_IMAGE_WIDTH = 128
+INPUT_IMAGE_HEIGHT = 128
+
+# define threshold to filter weak predictions
+THRESHOLD = 0.5
+
+# define the path to the base output directory
+BASE_OUTPUT = "output"
+
+# define the path to the output serialized model, model training
+# plot, and testing image paths
+MODEL_PATH = os.path.join(BASE_OUTPUT, "unet_tgs_salt.pth")
+PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"])
+TEST_PATHS = os.path.sep.join([BASE_OUTPUT, "test_paths.txt"])
\ No newline at end of file
diff --git a/pyimagesearch/dataset.py b/pyimagesearch/dataset.py
new file mode 100644
index 0000000..3ffa276
--- /dev/null
+++ b/pyimagesearch/dataset.py
@@ -0,0 +1,34 @@
+# import the necessary packages
+from torch.utils.data import Dataset
+import cv2
+
+class SegmentationDataset(Dataset):
+    def __init__(self, imagePaths, maskPaths, transforms):
+        # store the image and mask filepaths, and augmentation
+        # transforms
+        self.imagePaths = imagePaths
+        self.maskPaths = maskPaths
+        self.transforms = transforms
+
+    def __len__(self):
+        # return the number of total samples contained in the dataset
+        return len(self.imagePaths)
+
+    def __getitem__(self, idx):
+        # grab the image path from the current index
+        imagePath = self.imagePaths[idx]
+
+        # load the image from disk, swap its channels from BGR to RGB,
+        # and read the associated mask from disk in grayscale mode
+        image = cv2.imread(imagePath)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        mask = cv2.imread(self.maskPaths[idx], 0)
+
+        # check to see if we are applying any transformations
+        if self.transforms is not None:
+            # apply the transformations to both image and its mask
+            image = self.transforms(image)
+            mask = self.transforms(mask)
+
+        # return a tuple of the image and its mask
+        return (image, mask)
\ No newline at end of file
diff --git a/pyimagesearch/model.py b/pyimagesearch/model.py
new file mode 100644
index 0000000..496d014
--- /dev/null
+++ b/pyimagesearch/model.py
@@ -0,0 +1,124 @@
+# import the necessary packages
+from . import config
+from torch.nn import ConvTranspose2d
+from torch.nn import Conv2d
+from torch.nn import MaxPool2d
+from torch.nn import Module
+from torch.nn import ModuleList
+from torch.nn import ReLU
+from torchvision.transforms import CenterCrop
+from torch.nn import functional as F
+import torch
+
+class Block(Module):
+    def __init__(self, inChannels, outChannels):
+        super().__init__()
+        # store the convolution and RELU layers
+        self.conv1 = Conv2d(inChannels, outChannels, 3)
+        self.relu = ReLU()
+        self.conv2 = Conv2d(outChannels, outChannels, 3)
+
+    def forward(self, x):
+        # apply CONV => RELU => CONV block to the inputs and return it
+        return self.conv2(self.relu(self.conv1(x)))
+
+class Encoder(Module):
+    def __init__(self, channels=(3, 16, 32, 64)):
+        super().__init__()
+        # store the encoder blocks and maxpooling layer
+        self.encBlocks = ModuleList(
+            [Block(channels[i], channels[i + 1])
+                for i in range(len(channels) - 1)])
+        self.pool = MaxPool2d(2)
+
+    def forward(self, x):
+        # initialize an empty list to store the intermediate outputs
+        blockOutputs = []
+
+        # loop through the encoder blocks
+        for block in self.encBlocks:
+            # pass the inputs through the current encoder block, store
+            # the outputs, and then apply maxpooling on the output
+            x = block(x)
+            blockOutputs.append(x)
+            x = self.pool(x)
+
+        # return the list containing the intermediate outputs
+        return blockOutputs
+
+
+class Decoder(Module):
+    def __init__(self, channels=(64, 32, 16)):
+        super().__init__()
+        # initialize the number of channels, upsampler blocks, and
+        # decoder blocks
+        self.channels = channels
+        self.upconvs = ModuleList(
+            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
+                for i in range(len(channels) - 1)])
+        self.dec_blocks = ModuleList(
+            [Block(channels[i], channels[i + 1])
+                for i in range(len(channels) - 1)])
+
+    def forward(self, x, encFeatures):
+        # loop through the number of channels
+        for i in range(len(self.channels) - 1):
+            # pass the inputs through the upsampler blocks
+            x = self.upconvs[i](x)
+
+            # crop the current features from the encoder blocks,
+            # concatenate them with the current upsampled features,
+            # and pass the concatenated output through the current
+            # decoder block
+            encFeat = self.crop(encFeatures[i], x)
+            x = torch.cat([x, encFeat], dim=1)
+            x = self.dec_blocks[i](x)
+
+        # return the final decoder output
+        return x
+
+    def crop(self, encFeatures, x):
+        # grab the dimensions of the inputs, and crop the encoder
+        # features to match the dimensions
+        (_, _, H, W) = x.shape
+        encFeatures = CenterCrop([H, W])(encFeatures)
+
+        # return the cropped features
+        return encFeatures
+
+
+class UNet(Module):
+    def __init__(self, encChannels=(3, 16, 32, 64),
+            decChannels=(64, 32, 16),
+            nbClasses=1, retainDim=True,
+            outSize=(config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)):
+        super().__init__()
+        # initialize the encoder and decoder
+        self.encoder = Encoder(encChannels)
+        self.decoder = Decoder(decChannels)
+
+        # initialize the regression head and store the class variables
+        self.head = Conv2d(decChannels[-1], nbClasses, 1)
+        self.retainDim = retainDim
+        self.outSize = outSize
+
+    def forward(self, x):
+        # grab the features from the encoder
+        encFeatures = self.encoder(x)
+
+        # pass the encoder features through decoder making sure that
+        # their dimensions are suited for concatenation
+        decFeatures = self.decoder(encFeatures[::-1][0],
+            encFeatures[::-1][1:])
+
+        # pass the decoder features through the regression head to
+        # obtain the segmentation mask
+        map = self.head(decFeatures)
+
+        # check to see if we are retaining the original output
+        # dimensions and if so, then resize the output to match them
+        if self.retainDim:
+            map = F.interpolate(map, self.outSize)
+
+        # return the segmentation map
+        return map
\ No newline at end of file
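
A minimal sketch of how the new modules could be exercised, assuming the script is run from the repository root so that the pyimagesearch package is importable; the file name sanity_check.py and the dummy-tensor approach are illustrative assumptions, not part of this change:

# sanity_check.py -- hypothetical smoke test for the new UNet module
import torch

from pyimagesearch import config
from pyimagesearch.model import UNet

# build the model with its default channel layout (3 -> 16 -> 32 -> 64)
# and move it to the device selected in config.py
model = UNet().to(config.DEVICE)

# create a dummy batch of two 3-channel images at the configured input size
x = torch.randn(2, 3, config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH,
    device=config.DEVICE)

# run a forward pass without tracking gradients; because retainDim=True,
# the mask logits are interpolated back to the input resolution, so the
# expected shape is (2, NUM_CLASSES, 128, 128)
with torch.no_grad():
    pred = model(x)
print(pred.shape)

The interpolation step matters because every Block uses unpadded 3x3 convolutions, so the raw output of the regression head is smaller than the 128x128 input; F.interpolate in UNet.forward resizes the logits back to outSize before they are returned.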