Setting Up Image Segmentation on Ubuntu using PyTorch

This guide walks through setting up an image segmentation pipeline on Ubuntu using PyTorch and a pre-trained model from torchvision. Image segmentation assigns a class label to every pixel in an image, which makes it possible to separate objects or regions from one another.

1. Install System Prerequisites

Start by updating your Ubuntu system and installing necessary dependencies. Open a terminal and run the following commands:

sudo apt update
sudo apt upgrade
sudo apt install python3 python3-pip git
    

2. Install PyTorch and Torchvision

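Install PyTorch and Torchvision, which provide the deep learning framework and the pre-trained segmentation models. On Linux, the default pip packages are typically built with CUDA support, so the same command covers both cases: with a CUDA-compatible GPU and driver PyTorch can use the GPU, and without one it runs on the CPU. The install selector on pytorch.org lists the exact command for a specific CUDA version or a smaller CPU-only build. If pip reports an externally-managed-environment error, as recent Ubuntu releases do for system-wide installs, create and activate a virtual environment first (python3 -m venv, provided by the python3-venv package). The torchaudio package is not used in this guide and can be left out: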

pip install torch torchvision torchaudio
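
To confirm the installation and see whether a GPU is visible, a quick check along these lines can help. It is a minimal sketch: the function name describe_torch_install is illustrative and not part of PyTorch, and it only reports what torch itself exposes through torch.__version__ and torch.cuda.is_available().

```python
import importlib.util


def describe_torch_install():
    """Return a one-line summary of the local PyTorch install (hypothetical helper)."""
    # find_spec returns None when the package cannot be found
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"torch {torch.__version__} (device: {device})"


print(describe_torch_install())
```

A result of "device: cpu" on a machine with an NVIDIA GPU usually points to a missing driver or a CPU-only build rather than a problem with the script.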
    

3. Install Additional Libraries

Install the following additional libraries that will help with image processing and visualization:

pip install opencv-python matplotlib
    

These libraries will help load images and display segmentation results using OpenCV and Matplotlib.
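
One detail matters when combining the two libraries: OpenCV stores color images in BGR channel order, while Matplotlib expects RGB. The sketch below uses a tiny synthetic NumPy array (not a real image) to show that reversing the last axis performs the same reordering that cv2.cvtColor(image, cv2.COLOR_BGR2RGB) does in the script later in this guide.

```python
import numpy as np

# A 1x2 "image" in OpenCV's BGR order: one pure-blue pixel, one pure-red pixel
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the channel axis converts BGR to RGB
rgb = bgr[..., ::-1]

print(rgb[0, 0])  # the blue pixel, now in RGB order
print(rgb[0, 1])  # the red pixel, now in RGB order
```

Skipping this conversion does not affect whether the script runs, but the displayed colors would be swapped and the model would receive channels in an order it was not trained on.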

4. Download Pre-trained Image Segmentation Model

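We will use a pre-trained model for image segmentation. torchvision provides DeepLabV3 with a ResNet-50 backbone; its default weights were trained on a subset of COCO restricted to the 20 Pascal VOC categories, so the model predicts 21 classes including background. The model is imported as follows: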

from torchvision.models.segmentation import deeplabv3_resnet50
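
The class indices this model predicts follow the Pascal VOC label set: index 0 is background, followed by 20 object categories. The list below is written out by hand as a reference for reading the segmentation map; the names VOC_CLASSES and class_name are illustrative. On torchvision versions that expose weights metadata, the same list should be available as DeepLabV3_ResNet50_Weights.DEFAULT.meta["categories"], which is the safer source.

```python
# Pascal VOC labels, indexed by the class id found in the segmentation map
VOC_CLASSES = [
    "__background__", "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
    "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
    "tvmonitor",
]


def class_name(class_id):
    """Translate a class id from the segmentation map into a readable label."""
    return VOC_CLASSES[class_id]


print(class_name(15))  # person
```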
    

5. Create Python Script for Image Segmentation

Create a new Python script named image_segmentation.py that will load the pre-trained DeepLabV3 model and perform segmentation:

nano image_segmentation.py
    

Paste the following code into the script:

import torch
import torchvision
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F

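# ImageNet statistics that the pre-trained weights expect inputs to be normalized with
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Load the pre-trained DeepLabV3 model
def load_model():
    # weights="DEFAULT" replaces the deprecated pretrained=True (torchvision >= 0.13)
    model = torchvision.models.segmentation.deeplabv3_resnet50(weights="DEFAULT")
    model.eval()  # Set model to evaluation mode
    return model

# Preprocess the image for the model
def preprocess_image(image_path):
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising on failure
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
    image_tensor = F.to_tensor(image)  # HWC uint8 -> CHW float in [0, 1]
    image_tensor = F.normalize(image_tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD)
    return image, image_tensor.unsqueeze(0)  # Add batch dimension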

# Perform image segmentation
def segment_image(model, image_tensor):
    with torch.no_grad():
        output = model(image_tensor)['out'][0]  # Get the output tensor
    return output.argmax(0).byte().cpu().numpy()  # Get the segmentation map

# Plot segmentation results
def display_segmentation(original_image, segmented_image):
    plt.figure(figsize=(10, 5))

    # Display original image
    plt.subplot(1, 2, 1)
    plt.imshow(original_image)
    plt.title('Original Image')

    # Display segmentation map
    plt.subplot(1, 2, 2)
    plt.imshow(segmented_image)
    plt.title('Segmented Image')

    plt.show()

# Main function
if __name__ == "__main__":
    model = load_model()
    
    image_path = "path/to/your/image.jpg"  # Replace with the path to your image
    original_image, image_tensor = preprocess_image(image_path)

    segmented_image = segment_image(model, image_tensor)
    
    display_segmentation(original_image, segmented_image)
    

This script performs the following steps:

  • Loads a pre-trained DeepLabV3 segmentation model from PyTorch's torchvision library.
  • Converts the input image to an RGB tensor with a batch dimension, as the model expects; the image is kept at its original size.
  • Performs image segmentation and returns a segmentation map, where each pixel is assigned a class label.
  • Visualizes both the original image and the segmented image using Matplotlib.
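
The array shapes involved in these steps can be traced with NumPy alone. The sketch below is an illustration under stated assumptions, not the script itself: random values stand in for an image and for the model's output, NUM_CLASSES = 21 reflects the Pascal VOC label set used by torchvision's DeepLabV3 weights, and the mean and standard deviation are the ImageNet normalization constants those weights are documented to expect.

```python
import numpy as np

NUM_CLASSES = 21  # assumption: background + 20 Pascal VOC categories
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

rng = np.random.default_rng(0)

# Stand-in for an RGB image loaded from disk: height x width x channels, uint8
image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

# Equivalent of to_tensor: scale to [0, 1] and move channels first (C, H, W)
chw = (image.astype(np.float32) / 255.0).transpose(2, 0, 1)

# Per-channel normalization, then a leading batch dimension (1, C, H, W)
normalized = (chw - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]
batch = normalized[None, ...]

# Stand-in for the model output: one score per class per pixel (NUM_CLASSES, H, W)
scores = rng.standard_normal((NUM_CLASSES, 4, 6)).astype(np.float32)

# The segmentation map keeps, for each pixel, the index of the highest score
segmentation_map = scores.argmax(axis=0).astype(np.uint8)

print(batch.shape)             # (1, 3, 4, 6)
print(segmentation_map.shape)  # (4, 6)
```

The segmentation map has the same height and width as the input, which is why it can be shown side by side with the original image without any resizing.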

6. Download Pre-trained Model Weights

When you run the script for the first time, the pre-trained DeepLabV3 model weights will be automatically downloaded. Make sure you have an active internet connection during the first execution.
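
Downloaded weights are cached, so later runs work offline. The sketch below reproduces, with the standard library only, the default cache location used by torch.hub: the hub/checkpoints folder under TORCH_HOME, which falls back to ~/.cache/torch when neither TORCH_HOME nor XDG_CACHE_HOME is set. The helper name default_checkpoint_dir is hypothetical, and its answer will be wrong if the location was changed in code with torch.hub.set_dir; torch.hub.get_dir() is the authoritative source when torch is importable.

```python
import os
from pathlib import Path


def default_checkpoint_dir():
    """Best-effort guess at where downloaded weights are cached (hypothetical helper)."""
    cache_root = os.getenv("XDG_CACHE_HOME", "~/.cache")
    torch_home = os.getenv("TORCH_HOME", os.path.join(cache_root, "torch"))
    return Path(torch_home).expanduser() / "hub" / "checkpoints"


checkpoint_dir = default_checkpoint_dir()
print(checkpoint_dir)

# List any weight files that have already been downloaded
if checkpoint_dir.is_dir():
    for weight_file in sorted(checkpoint_dir.glob("*.pth")):
        print(weight_file.name)
```

Deleting a file from this folder forces a fresh download on the next run, which can help if a download was interrupted.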

7. Run the Image Segmentation Script

Once the script is ready, run the following command in your terminal to perform image segmentation:

python3 image_segmentation.py
    

Make sure to replace path/to/your/image.jpg with the actual path to the image you want to segment. The script will display the original image alongside the segmented output.
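
Editing the script for every new image gets tedious. One option, shown here as a sketch and not a required change, is to read the path from the command line with argparse. The parse_args helper and the positional argument name image are illustrative choices, and the example path is made up.

```python
import argparse
from pathlib import Path


def parse_args(argv=None):
    """Parse the image path; argv=None makes argparse read sys.argv."""
    parser = argparse.ArgumentParser(description="Segment an image with DeepLabV3")
    parser.add_argument("image", type=Path, help="path to the image to segment")
    return parser.parse_args(argv)


# Demonstration with an explicit argument list instead of the real command line
args = parse_args(["photos/example.jpg"])
print(args.image)
```

With this in place, the script would call parse_args() without arguments, be invoked as python3 image_segmentation.py followed by the image path, and pass str(args.image) to preprocess_image, since cv2.imread expects a string path.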

8. Adjust Segmentation Output

The output segmentation map is a single-channel array in which each pixel holds a class index. To visualize a specific class, compare the map against that index and plot the result in place of the full map in display_segmentation:

plt.imshow(segmented_image == 15)  # Example: show only class 15 ("person" in the Pascal VOC label set)
    

The comparison yields a boolean mask that is True only where a pixel belongs to class 15, so every other pixel is hidden and you can focus on a single object type or region.
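
A self-contained illustration of that masking step, using a small hand-written array in place of a real segmentation map (the array values and the PERSON constant are made up for the example):

```python
import numpy as np

PERSON = 15  # class id to keep; "person" in the Pascal VOC label set

# Toy segmentation map: 0 = background, 7 = car, 15 = person
segmented_image = np.array(
    [[0, 0, 15, 15],
     [0, 7, 15, 15],
     [0, 7, 7, 0]],
    dtype=np.uint8,
)

# Boolean mask that is True where the pixel is class 15
mask = segmented_image == PERSON

# Keep the selected class and set every other pixel to background (0)
isolated = np.where(mask, segmented_image, 0)

print(int(mask.sum()))  # 4 pixels are labeled 15
```

Passing mask or isolated to plt.imshow in display_segmentation shows only the selected class.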

9. Troubleshooting

If you encounter any issues, consider the following:

  • Ensure that all required Python libraries are installed.
  • Verify that the image path is correct and the file exists.
  • Ensure you have sufficient RAM for the model and the image; the script runs on the CPU as written, and very large images can be downscaled before inference to reduce memory use.
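
The first two checks can be automated before the model is loaded. This is a rough sketch: the function name preflight, the module list, and the message format are all assumptions made for the example. Note that one import name differs from its pip name (opencv-python is imported as cv2).

```python
import importlib.util
from pathlib import Path

# Import names (not pip names) of the packages the script relies on
REQUIRED_MODULES = ["torch", "torchvision", "cv2", "matplotlib"]


def preflight(image_path, modules=REQUIRED_MODULES):
    """Return a list of problems found; an empty list means all checks passed."""
    problems = []
    for name in modules:
        # find_spec returns None when a top-level package cannot be found
        if importlib.util.find_spec(name) is None:
            problems.append(f"missing Python package: {name}")
    if not Path(image_path).is_file():
        problems.append(f"image not found: {image_path}")
    return problems


for problem in preflight("path/to/your/image.jpg"):
    print(problem)
```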

10. Conclusion

Congratulations! You have successfully set up an Image Segmentation system on Ubuntu using PyTorch and the pre-trained DeepLabV3 model. This system can now segment images by identifying and classifying each pixel based on pre-trained categories.