# import the necessary packages
from . import config
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch


class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and RELU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        # apply CONV => RELU => CONV block to the inputs and return it
        return self.conv2(self.relu(self.conv1(x)))


class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []

        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the outputs, and then apply maxpooling on the output
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)

        # return the list containing the intermediate outputs
        return blockOutputs


class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])

    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler blocks
            x = self.upconvs[i](x)

            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)

        # return the final decoder output
        return x

    def crop(self, encFeatures, x):
        # grab the dimensions of the inputs, and crop the encoder
        # features to match the dimensions
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)

        # return the cropped features
        return encFeatures


class UNet(Module):
    def __init__(self, encChannels=(3, 16, 32, 64),
            decChannels=(64, 32, 16),
            nbClasses=1, retainDim=True,
            outSize=(config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)

        # initialize the regression head and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)

        # pass the encoder features through the decoder, making sure that
        # their dimensions are suited for concatenation
        decFeatures = self.decoder(encFeatures[::-1][0],
            encFeatures[::-1][1:])

        # pass the decoder features through the regression head to
        # obtain the segmentation mask
        map = self.head(decFeatures)

        # check to see if we are retaining the original output
        # dimensions and if so, then resize the output to match them
        if self.retainDim:
            map = F.interpolate(map, self.outSize)

        # return the segmentation map
        return map
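
# ----------------------------------------------------------------------
# Minimal sanity-check sketch (an addition, not part of the original
# module): it assumes this file lives in a package whose `config` module
# defines INPUT_IMAGE_HEIGHT and INPUT_IMAGE_WIDTH (e.g. 128), so run it
# with `python -m <package>.model` to keep the relative import working.
if __name__ == "__main__":
    # build the model and a dummy batch of two 3-channel images whose
    # spatial size matches the configured input dimensions
    model = UNet()
    x = torch.randn(2, 3, config.INPUT_IMAGE_HEIGHT,
        config.INPUT_IMAGE_WIDTH)

    # the forward pass yields one mask channel per class; with
    # retainDim=True the output is interpolated back to outSize
    out = model(x)
    print(out.shape)  # expected: (2, 1, INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)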