Add files via upload
commit 3f6433f7ab (parent d0512f6a95)
@@ -0,0 +1,47 @@
# import the necessary packages
import torch
import os

# base path of the dataset
DATASET_PATH = os.path.join("dataset", "train")

# define the paths to the images and masks datasets
IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "images")
MASK_DATASET_PATH = os.path.join(DATASET_PATH, "masks")

# define the test split
TEST_SPLIT = 0.15

# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cuda" else False

# define the number of channels in the input, the number of classes,
# and the number of levels in the U-Net model
NUM_CHANNELS = 1
NUM_CLASSES = 1
NUM_LEVELS = 3

# initialize the learning rate, the number of epochs to train for,
# and the batch size
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 64

# define the input image dimensions
INPUT_IMAGE_WIDTH = 128
INPUT_IMAGE_HEIGHT = 128

# define the threshold to filter weak predictions
THRESHOLD = 0.5

# define the path to the base output directory
BASE_OUTPUT = "output"

# define the paths to the serialized model, the training plot,
# and the testing image paths file
MODEL_PATH = os.path.join(BASE_OUTPUT, "unet_tgs_salt.pth")
PLOT_PATH = os.path.join(BASE_OUTPUT, "plot.png")
TEST_PATHS = os.path.join(BASE_OUTPUT, "test_paths.txt")
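For orientation, here is a minimal sketch of how these constants are typically consumed when building the train/test split. This is an illustration, not part of this commit: train_test_split comes from scikit-learn, the "*.png" glob pattern is an assumption about the dataset layout, and the variable names are hypothetical.

# illustrative sketch only -- not part of this commit
import glob
from sklearn.model_selection import train_test_split

# collect sorted image and mask paths so the two lists stay aligned
imagePaths = sorted(glob.glob(os.path.join(IMAGE_DATASET_PATH, "*.png")))
maskPaths = sorted(glob.glob(os.path.join(MASK_DATASET_PATH, "*.png")))

# split both lists with the configured test fraction
split = train_test_split(imagePaths, maskPaths,
    test_size=TEST_SPLIT, random_state=42)
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]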
@@ -0,0 +1,34 @@
# import the necessary packages
from torch.utils.data import Dataset
import cv2

class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, transforms):
        # store the image and mask filepaths, and the augmentation
        # transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        # return the total number of samples contained in the dataset
        return len(self.imagePaths)

    def __getitem__(self, idx):
        # grab the image path at the current index
        imagePath = self.imagePaths[idx]

        # load the image from disk, swap its channels from BGR to RGB,
        # and read the associated mask from disk in grayscale mode
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.maskPaths[idx], cv2.IMREAD_GRAYSCALE)

        # check to see if we are applying any transformations
        if self.transforms is not None:
            # apply the transformations to both the image and its mask
            image = self.transforms(image)
            mask = self.transforms(mask)

        # return a tuple of the image and its mask
        return (image, mask)
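A usage sketch, again an assumption for illustration rather than part of this commit: wiring SegmentationDataset into a DataLoader with torchvision transforms. The same transform chain is applied to both the image and its mask, so only geometric and scaling operations belong in it. trainImages/trainMasks are the hypothetical path lists from the split sketch above, and config is assumed to be the configuration module from this commit.

# illustrative sketch only -- not part of this commit
from torch.utils.data import DataLoader
from torchvision import transforms

# resize everything to the configured input size and convert to tensors
tfms = transforms.Compose([transforms.ToPILImage(),
    transforms.Resize((config.INPUT_IMAGE_HEIGHT,
        config.INPUT_IMAGE_WIDTH)),
    transforms.ToTensor()])

# build the training dataset and loader from the config values
trainDS = SegmentationDataset(imagePaths=trainImages,
    maskPaths=trainMasks, transforms=tfms)
trainLoader = DataLoader(trainDS, shuffle=True,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY)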
@@ -0,0 +1,124 @@
# import the necessary packages
from . import config
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch

class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and ReLU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        # apply CONV => RELU => CONV to the inputs and return the result
        return self.conv2(self.relu(self.conv1(x)))

class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and the maxpooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []

        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the output, and then apply maxpooling to it
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)

        # return the list containing the intermediate outputs
        return blockOutputs

class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, the upsampler blocks, and
        # the decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])

    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the current upsampler block
            x = self.upconvs[i](x)

            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)

        # return the final decoder output
        return x

    def crop(self, encFeatures, x):
        # grab the spatial dimensions of the inputs, and crop the
        # encoder features to match them (the unpadded convolutions
        # in Block shrink the feature maps, so the skip connections
        # must be center-cropped before concatenation)
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)

        # return the cropped features
        return encFeatures

class UNet(Module):
    def __init__(self, encChannels=(3, 16, 32, 64),
            decChannels=(64, 32, 16),
            nbClasses=1, retainDim=True,
            outSize=(config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)

        # initialize the regression head and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)

        # pass the encoder features through the decoder, making sure
        # their dimensions are suited for concatenation (the deepest
        # features go in first; the remaining ones serve as skips)
        decFeatures = self.decoder(encFeatures[::-1][0],
            encFeatures[::-1][1:])

        # pass the decoder features through the regression head to
        # obtain the segmentation mask
        segMap = self.head(decFeatures)

        # check to see if we are retaining the original output
        # dimensions and, if so, resize the output to match them
        if self.retainDim:
            segMap = F.interpolate(segMap, self.outSize)

        # return the segmentation map
        return segMap
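As a quick sanity check (illustrative, not part of this commit), a dummy batch can be pushed through the model. With the defaults above, the input must have 3 channels to match encChannels, and retainDim interpolates the output back to the configured 128x128 size:

# illustrative sketch only -- not part of this commit
model = UNet()
x = torch.randn(2, 3, config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([2, 1, 128, 128])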