71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
import torch
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from torchvision import transforms
|
|
import cv2
|
|
from PIL import Image
|
|
from pyimagesearch import config
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Define the U-Net model architecture (similar to what's in the article)
|
|
# Ensure to load the model and its weights as per your training script
|
|
|
|
# Function to preprocess the input image
|
|
def preprocess_image(image_path):
|
|
image = cv2.imread(image_path)
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
# Define your preprocessing steps (resize, normalize, etc.)
|
|
transform = transforms.Compose([transforms.ToPILImage(),
|
|
transforms.Resize((config.INPUT_IMAGE_HEIGHT,
|
|
config.INPUT_IMAGE_WIDTH)),
|
|
transforms.ToTensor()])
|
|
image = transform(image)
|
|
image = image.unsqueeze(0) # Add batch dimension
|
|
return image
|
|
|
|
# Function to predict the mask using the trained U-Net model
|
|
def predict_mask(model, image):
|
|
model.eval()
|
|
with torch.no_grad():
|
|
output = model(image)
|
|
predicted_mask = torch.argmax(output, dim=1).squeeze(0)
|
|
print(predicted_mask)
|
|
return predicted_mask.numpy()
|
|
|
|
# Function to display the original image and its predicted mask
|
|
def display_image_with_mask(image_path, mask):
|
|
image = Image.open(image_path)
|
|
plt.figure(figsize=(10, 5))
|
|
plt.subplot(1, 2, 1)
|
|
plt.imshow(image)
|
|
plt.title('Original Image')
|
|
plt.axis('off')
|
|
|
|
plt.subplot(1, 2, 2)
|
|
print(mask.shape)
|
|
plt.imshow(mask)
|
|
plt.title('Predicted Mask')
|
|
plt.axis('off')
|
|
|
|
plt.show()
|
|
|
|
# Main function to run the script
|
|
def main():
|
|
# Load your trained model
|
|
model = torch.load(config.MODEL_PATH).to(config.DEVICE)
|
|
|
|
# Accept user input for image path
|
|
image_path = input("Enter path to the image: ")
|
|
|
|
# Preprocess the input image
|
|
image = preprocess_image(image_path)
|
|
# predicted_mask = Image.open("c:\\Users\\anon\\Desktop\\Summer Research-COPY\\unet\\dataset\\train\\masks\\0aab0afa9c.png").convert("RGB")
|
|
# Predict the mask
|
|
predicted_mask = predict_mask(model, image)
|
|
|
|
# Display the image with its predicted mask
|
|
display_image_with_mask(image_path, predicted_mask)
|
|
|
|
main() |