# USAGE
# python train.py

# import the necessary packages
from pyimagesearch.dataset import SegmentationDataset
from pyimagesearch.model import UNet
from pyimagesearch import config
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time

# load the image and mask filepaths in a sorted manner
imagePaths = sorted(list(paths.list_images(config.IMAGE_DATASET_PATH)))
maskPaths = sorted(list(paths.list_images(config.MASK_DATASET_PATH)))

# partition the data into training and testing splits using 85% of
# the data for training and the remaining 15% for testing
split = train_test_split(imagePaths, maskPaths,
	test_size=config.TEST_SPLIT, random_state=42)

# unpack the data split
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]

# write the testing image paths to disk so that we can use them
# when evaluating/testing our model
print("[INFO] saving testing image paths...")
with open(config.TEST_PATHS, "w") as f:
	f.write("\n".join(testImages))

# define transformations (named imageTransforms to avoid shadowing the
# torchvision transforms module imported above)
imageTransforms = transforms.Compose([transforms.ToPILImage(),
	transforms.Resize((config.INPUT_IMAGE_HEIGHT,
		config.INPUT_IMAGE_WIDTH)),
	transforms.ToTensor()])

# create the train and test datasets
trainDS = SegmentationDataset(imagePaths=trainImages,
	maskPaths=trainMasks, transforms=imageTransforms)
testDS = SegmentationDataset(imagePaths=testImages,
	maskPaths=testMasks, transforms=imageTransforms)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")

# create the training and test data loaders
trainLoader = DataLoader(trainDS, shuffle=True,
	batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
	num_workers=0)
testLoader = DataLoader(testDS, shuffle=False,
	batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
	num_workers=0)
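# (optional sanity check, a minimal sketch left commented out: pull a
# single batch to confirm the loaders yield image/mask tensors of the
# expected shape before committing to a full training run)
# (images, masks) = next(iter(trainLoader))
# print("[INFO] batch shapes:", images.shape, masks.shape)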
# initialize our UNet model
unet = UNet().to(config.DEVICE)

# initialize loss function and optimizer; BCEWithLogitsLoss applies the
# sigmoid internally, so the model should output raw logits
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=config.INIT_LR)

# calculate steps per epoch for the training and test sets
trainSteps = len(trainDS) // config.BATCH_SIZE
testSteps = len(testDS) // config.BATCH_SIZE

# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.NUM_EPOCHS)):
	# set the model in training mode
	unet.train()

	# initialize the total training and validation loss
	totalTrainLoss = 0
	totalTestLoss = 0

	# loop over the training set
	for (x, y) in trainLoader:
		# send the input to the device
		(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

		# perform a forward pass and calculate the training loss
		pred = unet(x)
		loss = lossFunc(pred, y)

		# first, zero out any previously accumulated gradients, then
		# perform backpropagation, and then update model parameters
		opt.zero_grad()
		loss.backward()
		opt.step()

		# add the loss to the total training loss so far; .item()
		# detaches the scalar so we do not keep the autograd graph alive
		totalTrainLoss += loss.item()

	# switch off autograd
	with torch.no_grad():
		# set the model in evaluation mode
		unet.eval()

		# loop over the validation set
		for (x, y) in testLoader:
			# send the input to the device
			(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

			# make the predictions and calculate the validation loss
			pred = unet(x)
			totalTestLoss += lossFunc(pred, y).item()

	# calculate the average training and validation loss
	avgTrainLoss = totalTrainLoss / trainSteps
	avgTestLoss = totalTestLoss / testSteps

	# update our training history
	H["train_loss"].append(avgTrainLoss)
	H["test_loss"].append(avgTestLoss)

	# print the model training and validation information
	print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
	print("Train loss: {:.4f}, Test loss: {:.4f}".format(
		avgTrainLoss, avgTestLoss))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
	endTime - startTime))

# plot the training loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(config.PLOT_PATH)

# serialize the model to disk
torch.save(unet, config.MODEL_PATH)
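# (a minimal reload-for-inference sketch, left commented out since it is
# not part of the training run: torch.save above pickles the entire
# model, so torch.load needs the UNet class importable at load time, and
# recent PyTorch releases may also require passing weights_only=False;
# x below stands for a hypothetical preprocessed image batch)
# model = torch.load(config.MODEL_PATH, map_location=config.DEVICE)
# model.eval()
# with torch.no_grad():
# 	logits = model(x.to(config.DEVICE))
# 	predMask = torch.sigmoid(logits)  # sigmoid, since the model
# 	                                  # outputs raw logits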