# computer_vision_for_biology/test1.py

# 71 lines
# 2.2 KiB
# Python

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import cv2
from PIL import Image
from pyimagesearch import config
import matplotlib.pyplot as plt
import numpy as np
import torch
# The U-Net model architecture (as in the article / your training script) must be
# defined or importable in this file: torch.load() restores the pickled model
# object and needs its class to do so. Load the model exactly as it was saved.
# Function to preprocess the input image
def preprocess_image(image_path):
    """Load an image from disk and turn it into a batched tensor for the U-Net.

    The image is read with OpenCV, converted from OpenCV's BGR channel order
    to RGB, resized to (config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)
    and converted to a float tensor with ``transforms.ToTensor()``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A tensor of shape (1, 3, INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH); the
        leading dimension is the batch axis expected by the model.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path`` (missing file,
            unreadable or unsupported format).
    """
    image = cv2.imread(image_path)
    # cv2.imread does not raise on failure; it returns None, which would
    # otherwise surface as an obscure error inside cv2.cvtColor below.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Keep these steps in sync with the training-time preprocessing.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)),
        transforms.ToTensor(),
    ])
    image = transform(image)
    return image.unsqueeze(0)  # add batch dimension
# Function to predict the mask using the trained U-Net model
def predict_mask(model, image, threshold=0.5):
    """Run ``model`` on a single preprocessed image and return its label mask.

    Args:
        model: A ``torch.nn.Module`` mapping a (1, C, H, W) batch to per-class
            scores of shape (1, K, H, W). Assumed to return raw logits (no
            final sigmoid/softmax) -- confirm against the model definition.
        image: Batched input tensor of shape (1, C, H, W), e.g. from
            ``preprocess_image``. It is moved to the model's device if needed.
        threshold: Probability cut-off, used only when the model emits a
            single output channel (binary segmentation).

    Returns:
        A NumPy int64 array of shape (H, W). For K > 1 each pixel holds the
        index of its highest-scoring class; for K == 1 each pixel is 1 where
        ``sigmoid(logit) > threshold`` and 0 otherwise.
    """
    model.eval()
    # Run on whatever device the model's weights live on so the forward pass
    # does not hit a CPU/GPU mismatch; parameter-free models keep the input
    # on its current device.
    param = next(model.parameters(), None)
    if param is not None:
        image = image.to(param.device)

    with torch.no_grad():
        output = model(image)

    if output.shape[1] == 1:
        # Binary model: argmax over a single channel is always 0 (an empty
        # mask), so threshold the sigmoid probability instead.
        mask = (torch.sigmoid(output[:, 0]) > threshold).long()
    else:
        mask = torch.argmax(output, dim=1)

    # .cpu() is required before .numpy() when the model ran on a GPU.
    return mask.squeeze(0).cpu().numpy()
# Function to display the original image and its predicted mask
def display_image_with_mask(image_path, mask):
    """Show the input image and a predicted mask side by side in one figure.

    Args:
        image_path: Path of the original image file (opened with PIL).
        mask: 2-D array drawn in the right-hand panel, e.g. the result of
            ``predict_mask``. Its resolution may differ from the original
            image, so each panel is drawn at its own size.
    """
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    # Image.open is lazy and keeps the file open; the context manager closes
    # it once imshow has read the pixel data.
    with Image.open(image_path) as image:
        plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(mask)
    plt.title('Predicted Mask')
    plt.axis('off')

    plt.show()
# Main function to run the script
def main():
    """Interactively segment one image with the trained U-Net and show it.

    Loads the model stored at ``config.MODEL_PATH`` onto ``config.DEVICE``,
    asks the user for an image path, predicts its mask and displays the
    image next to the mask.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine (and vice versa).
    # NOTE(review): this unpickles a full model object, so the U-Net class
    # must be importable and the file must be trusted. Newer torch releases
    # default torch.load to weights_only=True, which may reject such a file;
    # pass weights_only=False there if loading fails -- confirm against the
    # installed torch version.
    model = torch.load(config.MODEL_PATH, map_location=config.DEVICE).to(config.DEVICE)

    # Strip surrounding whitespace and quotes (e.g. from "Copy as path").
    image_path = input("Enter path to the image: ").strip().strip("\"'")

    # Keep the input tensor on the same device as the model.
    image = preprocess_image(image_path).to(config.DEVICE)

    predicted_mask = predict_mask(model, image)
    display_image_with_mask(image_path, predicted_mask)
# Run only when executed as a script, so importing this module (e.g. to reuse
# the helper functions) does not prompt for input.
if __name__ == "__main__":
    main()