90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
# USAGE
|
|
# python predict.py
|
|
|
|
# import the necessary packages
|
|
from pyimagesearch import config
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
import os
|
|
|
|
def prepare_plot(origImage, origMask, predMask):
    """Display an image, its ground-truth mask and a predicted mask in a row.

    Builds a 1x3 matplotlib figure, draws each array with ``imshow`` under a
    fixed title, tightens the layout and calls ``show()`` on the figure.

    Args:
        origImage: the input image as passed to ``imshow``.
        origMask: the ground-truth segmentation mask.
        predMask: the model's (thresholded) predicted mask.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))

    # one (array, title) pair per subplot, left to right
    panels = (
        (origImage, "Image"),
        (origMask, "Original Mask"),
        (predMask, "Predicted Mask"),
    )
    for axis, (picture, title) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(title)

    fig.tight_layout()
    fig.show()
|
|
|
|
|
|
def make_predictions(model, imagePath):
    """Segment one image with ``model`` and plot it next to its ground truth.

    Loads the image at ``imagePath`` (converted to RGB, float32 in [0, 1]) and
    its ground-truth mask, which is looked up by file name under
    ``config.MASK_DATASET_PATH`` and read in grayscale. Both are resized to the
    same input size. The image is run through ``model``, the output is squeezed,
    passed through a sigmoid and thresholded at ``config.THRESHOLD`` into a
    0/255 uint8 mask. The three arrays are then handed to ``prepare_plot``.

    Args:
        model: a PyTorch module. Presumably it emits a single logit map per
            image (the output is squeezed and sigmoid-ed); confirm against the
            model definition.
        imagePath: path to the input image on disk.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image or its mask
            (``cv2.imread`` returns ``None`` rather than raising).
    """
    # cv2.resize expects (width, height). The image and the mask must share a
    # single size so the plotted panels line up with each other and with the
    # model's input size.
    # NOTE(review): the config module is not visible here, so the width is
    # looked up defensively and falls back to a square input of
    # INPUT_IMAGE_HEIGHT.
    height = config.INPUT_IMAGE_HEIGHT
    width = getattr(config, "INPUT_IMAGE_WIDTH", height)
    size = (width, height)

    # switch the model to inference behaviour
    model.eval()

    # load the image from disk, swap BGR -> RGB, cast to float32 and scale
    # pixel values to [0, 1]
    image = cv2.imread(imagePath)
    if image is None:
        raise FileNotFoundError(f"could not read image: {imagePath}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype("float32") / 255.0

    # resize the image and keep a copy for visualization
    image = cv2.resize(image, size)
    orig = image.copy()

    # the ground-truth mask shares the image's file name inside the mask dir
    filename = os.path.basename(imagePath)
    groundTruthPath = os.path.join(config.MASK_DATASET_PATH, filename)
    gtMask = cv2.imread(groundTruthPath, cv2.IMREAD_GRAYSCALE)
    if gtMask is None:
        raise FileNotFoundError(f"could not read mask: {groundTruthPath}")
    gtMask = cv2.resize(gtMask, size)

    # HWC -> CHW, add a leading batch axis of 1, move to the target device
    batch = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
    batch = torch.from_numpy(batch).to(config.DEVICE)

    # only the forward pass needs gradient tracking disabled
    with torch.no_grad():
        logits = model(batch).squeeze()
        probs = torch.sigmoid(logits).cpu().numpy()

    # keep confident pixels as 255 and drop the rest to 0
    predMask = ((probs > config.THRESHOLD) * 255).astype(np.uint8)

    prepare_plot(orig, gtMask, predMask)
|
|
|
|
# load the image paths listed in the testing file (one per line, blank lines
# ignored) and randomly select up to 10 distinct image paths
print("[INFO] loading up test image paths...")
with open(config.TEST_PATHS) as f:
    imagePaths = [line for line in f.read().splitlines() if line.strip()]
if not imagePaths:
    raise SystemExit(f"[ERROR] no image paths found in {config.TEST_PATHS}")

# sample without replacement so no image is shown twice; cap the sample size
# so short test lists still work
imagePaths = np.random.choice(
    imagePaths, size=min(10, len(imagePaths)), replace=False
)

# load our model from disk and move it to the current device. map_location
# lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which rejects a
# pickled full nn.Module; confirm the installed torch version or pass
# weights_only=False explicitly.
print("[INFO] load up model...")
unet = torch.load(config.MODEL_PATH, map_location=config.DEVICE).to(config.DEVICE)

# iterate over the randomly selected test image paths
for path in imagePaths:
    # make predictions and visualize the results
    make_predictions(unet, path)