Permutation Channel Importance
A global interpretability method for understanding which channels in a computer vision model are most important

In practice, an image is a 2D grid of pixels. In a standard colour image, each pixel has three values, which we call the RGB (red, green and blue) channels. Permutation Channel Importance (PCI) is good for obtaining one insight: which of these channels are important to a model’s predictions.

PCI does this by shuffling (or permuting) the pixels in a given channel. We do this for every image in a dataset. Then, comparing the performance of a trained model, before and after permutation, will reveal whether it uses that channel to make predictions.

This is a relatively simple XAI method. Yet, other methods like occlusion and SHAP will use permutations in a similar way. Understanding PCI will provide a good basis for understanding these more complex methods.

So in this lesson, we will:

  • Explain the role of permutation in computer vision
  • Give the steps of the PCI algorithm
  • Explain the applications of PCI
  • Apply PCI to a coastal image segmentation model

For RGB images, PCI can help explain whether colour is an important aspect of prediction. However, we will see that the true power of the method comes when applying it to remote sensing problems, which typically involve more complex inputs with many channels.

You can find all the code for the results and figures in this article on GitHub. You may also find this video on the topic useful. There is another one later on for the coding section.

Explaining Computer Vision Models with PCI



Permutation in Computer Vision

In computer vision, permutation is when we rearrange the pixels, areas, or features of an image in a random or structured manner. For XAI, the goal is to evaluate the importance of different parts of an image for a model’s prediction. This is done by observing how the model’s performance changes when those parts are permuted.

To be clear, we will always be explaining a model trained on the original data. We compare this model’s performance on the original and permuted data. We do not train a new model on permuted data.

Many XAI methods rely on permutations in some form. Occlusion maps are created by systematically masking squares of pixels. In comparison, SHAP permutes different combinations of sets of pixels. These can be individual pixels or groups called superpixels. This is to estimate their marginal contributions (i.e. Shapley values) to the model’s prediction.

We will discuss both occlusion and SHAP in later lessons. For now, we will be using permutations in a different way.

Permuting a Channel

For PCI, we shuffle every pixel in a channel. You can see this in Figure 1 where the red channel has been shuffled. Using the “Original Image” we can predict what type of plant this is. If the predicted probabilities using the “Permuted Image” change significantly, then the red channel is being used by the model to make the prediction for this instance. Repeating this process for all channels and images in a dataset will tell us which channels are important in general.

Figure 1: permuting the red channel of an RGB image (source: author)
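To make the operation in Figure 1 concrete, here is a minimal NumPy sketch of permuting one channel. It assumes the image is stored as a (height, width, channels) array; the function name `permute_channel` is illustrative and not part of the article's code.

```python
import numpy as np

def permute_channel(image, channel, rng=None):
    """Return a copy of an (H, W, C) image with all pixels of one channel shuffled.

    Other channels are untouched; the shuffled channel keeps exactly the same
    pixel values, just in a random spatial order.
    """
    rng = np.random.default_rng() if rng is None else rng
    permuted = image.copy()
    height, width = image.shape[:2]
    values = permuted[:, :, channel].ravel()                  # flatten one channel to 1D
    permuted[:, :, channel] = rng.permutation(values).reshape(height, width)
    return permuted

# Toy 4x4 RGB image with random 8-bit pixel values
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
shuffled = permute_channel(image, channel=0, rng=rng)         # permute the red channel
```

Only the spatial arrangement of the red values changes, so any drop in performance comes from the model relying on where those values appear.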

So, unlike occlusion or SHAP, which can tell us which region of an image is important for an individual instance, PCI will tell us if an entire channel is important across all instances. In other words, PCI is a global interpretability method.

This is a similar insight to Permutation Feature Importance (PFI) for tabular data. It can tell us which model feature is most important in general. In comparison, applying SHAP to tabular data can tell us how each feature has contributed to an individual prediction.

In a later section, we discuss the insights PCI can bring about our model. Before that, it is worth defining the PCI algorithm more formally.

The PCI algorithm

To calculate PCI scores, we start with some initial values/choices:

  • c is the number of channels.
  • n is the number of images.
  • k is the number of repetitions to average over.
  • f(x) is a trained model.
  • p is our performance metric.

For RGB images, c will be 3. We will see that for other areas like remote sensing, c can be larger. n is usually the size of the validation or training dataset. Lastly, we need k because the permutation process is random. Repeating it will give us a more stable estimate of the PCI scores.

To calculate PCI, we start by getting a baseline performance value. To do this, we use the n images, before they are permuted, as input into f(x). Using the model’s predictions, we then calculate the value of p. This is our baseline.

Then, for a given channel i, we calculate the PCI score by:

  1. Permutation: For each of the n images, we randomly shuffle the pixels of channel i within that image.
  2. Prediction: We use the permuted images as input to the trained model to generate predictions.
  3. Performance: We calculate p across the n images using the model’s predictions with the permuted channel.
  4. Importance: We compare the permuted performance to the baseline. This is typically done by subtracting the values to find the decrease in performance.
  5. Repeat: We repeat the steps k times and take the average decrease in performance.
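Written out with the symbols defined above, the steps combine into a single formula. Here X denotes the n original images and $\tilde{X}^{(r)}_i$ the same images with channel i permuted in repetition r; this notation is ours, and p is understood as being computed from the model's predictions and the targets:

$$
\mathrm{PCI}_i = p\big(f(X)\big) - \frac{1}{k}\sum_{r=1}^{k} p\big(f(\tilde{X}^{(r)}_i)\big)
$$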

In the end, the average decrease in performance is our PCI score. The precise interpretation will depend on our choice of p. For classification, we can use metrics like accuracy or AUC. For regression, we can use MSE or R-squared. When we apply PCI, we will be dealing with an image segmentation model. In this case, we use average accuracy as our performance metric.
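The steps above can be sketched as a model-agnostic function. This is a minimal NumPy sketch, not the article's implementation (which appears later in PyTorch). It assumes images are stored as an (n, channels, height, width) array and that higher metric values are better; the names `pci_scores`, `permute_channel_batch`, `toy_predict` and `mean_pixel_accuracy` are hypothetical.

```python
import numpy as np

def permute_channel_batch(images, channel, rng):
    """Shuffle one channel's pixels independently within each image.
    images: array of shape (n, c, height, width)."""
    permuted = images.copy()
    n, _, height, width = images.shape
    for j in range(n):
        values = permuted[j, channel].ravel()
        permuted[j, channel] = rng.permutation(values).reshape(height, width)
    return permuted

def pci_scores(predict, images, targets, metric, k=3, seed=0):
    """Permutation Channel Importance for every channel.
    predict: maps an (n, c, h, w) array to predictions
    metric:  (targets, predictions) -> performance, where higher is better
    k:       number of repetitions to average over
    """
    rng = np.random.default_rng(seed)
    baseline = metric(targets, predict(images))          # performance on original images
    scores = []
    for i in range(images.shape[1]):                      # one score per channel
        permuted_perf = [
            metric(targets, predict(permute_channel_batch(images, i, rng)))
            for _ in range(k)                             # repeat the random permutation k times
        ]
        scores.append(baseline - np.mean(permuted_perf))  # average decrease in performance
    return np.array(scores)

# Hypothetical toy problem: a per-pixel "segmentation model" that only uses channel 0
data_rng = np.random.default_rng(42)
images = data_rng.random((20, 3, 8, 8))   # 20 toy images with 3 channels
targets = images[:, 0] > 0.5              # toy masks defined by channel 0 alone

def toy_predict(x):
    return x[:, 0] > 0.5

def mean_pixel_accuracy(t, p):
    """Average of per-image pixel accuracies, the segmentation metric used in this article."""
    return np.mean([np.mean(ti == pi) for ti, pi in zip(t, p)])

scores = pci_scores(toy_predict, images, targets, mean_pixel_accuracy, k=3)
print(np.round(scores, 3))
```

Because the toy model only reads channel 0, permuting channels 1 and 2 leaves its predictions unchanged, so only channel 0 receives a non-zero score.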

Applications of PCI

As mentioned, PCI can tell us one thing: which channels are most important. The usefulness of this will depend on your problem. For RGB images, we can tell whether colour is important or whether the images could be greyscaled. However, we will see that it becomes more useful when dealing with multispectral images. These include additional channels beyond the three visible light bands.

RGB images

To understand the insights for RGB images, let’s consider the Pot Plant dataset (CC BY 4.0). This is an image classification problem where we aim to predict which pot plant is in the image. You can see an example from each of the four classes in Figure 2.

Figure 2: examples from the validation set of the pot plant dataset. The dataset contains images of 4 different house plants. The target (i.e. names) for each plant is given above the image. (source: author)

When we apply PCI to a model trained on this dataset, we get similar scores for each channel. Specifically, the accuracy on the validation set decreased by 0.35, 0.39 and 0.35 for the Red, Green and Blue channels respectively. This tells us that all channels are being used to make predictions.

Looking at Figure 3, we can gain even more insight. For this plot, we calculated the PCI scores for each class (i.e. pot plant) separately. We can see that a channel’s importance varies by class. For example, the green channel is most important when making predictions for Rudo but the least important for Greg.

Figure 3: permutation channel importance scores by class. They are calculated by dividing the validation set into 4 groups based on the class label (i.e. plant’s name). The PCI for a channel is the accuracy for that class less the accuracy when the channel is permuted. (source: author)
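The per-class scores in Figure 3 only need the predictions made before and after permuting a channel, grouped by true label. Here is a minimal sketch, assuming single-label classification and that both sets of predictions have already been computed; `pci_by_class` and the toy predictions are hypothetical.

```python
import numpy as np

def pci_by_class(labels, preds_original, preds_permuted):
    """PCI for one channel, computed separately for each true class.

    Returns {class: accuracy on original images - accuracy with the channel permuted},
    where accuracy is measured only over the images whose true label is that class.
    """
    labels = np.asarray(labels)
    preds_original = np.asarray(preds_original)
    preds_permuted = np.asarray(preds_permuted)
    scores = {}
    for cls in np.unique(labels):
        in_class = labels == cls                          # images of this class
        acc_original = np.mean(preds_original[in_class] == cls)
        acc_permuted = np.mean(preds_permuted[in_class] == cls)
        scores[cls.item()] = float(acc_original - acc_permuted)
    return scores

# Made-up predictions for six images from two classes (0 and 1)
labels = [0, 0, 0, 1, 1, 1]
preds_original = [0, 0, 0, 1, 1, 0]
preds_permuted = [0, 1, 1, 1, 1, 0]
scores = pci_by_class(labels, preds_original, preds_permuted)
```

In this toy case, permuting the channel hurts class 0 but not class 1, which is the kind of difference Figure 3 reveals between plants.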

We see this result because the brightly coloured pots have introduced bias into our dataset. In a later lesson, we will use saliency maps to show that the model is using the pixels from the pots and not the plants to make predictions. The results in Figure 3 provide additional evidence that it is the colour of the pots that is causing errors.

In general, we can use PCI to understand whether colour is important to the prediction. If only one channel is used by the model, then we can simplify the model input by greyscaling or by selecting only the channels that are important. This can be useful, but these insights become far more valuable when dealing with more complex data sources.
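Both simplifications are short operations in NumPy. This sketch assumes channels-last (H, W, C) arrays. The ITU-R BT.601 luma weights are one common convention for greyscale conversion, not necessarily the right choice for a given model, and the helper names are hypothetical.

```python
import numpy as np

def to_greyscale(rgb):
    """Collapse an (H, W, 3) RGB image to a single (H, W) channel.

    Uses the ITU-R BT.601 luma weights, one common convention among several.
    """
    weights = np.array([0.299, 0.587, 0.114])
    return rgb[..., :3] @ weights            # weighted sum over the channel axis

def select_channels(image, keep):
    """Keep only the listed channels of an (H, W, C) image, e.g. those with a high PCI."""
    return image[..., list(keep)]

rgb = np.random.default_rng(0).random((4, 4, 3))
grey = to_greyscale(rgb)                       # shape (4, 4)
red_and_green = select_channels(rgb, [0, 1])   # shape (4, 4, 2)
```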

Multi-spectral imagery

The RGB colour we see is reflected radiation, or light waves of different wavelengths. In remote sensing, we deal with “images” taken with advanced sensors on satellites or aircraft. Multispectral images also include wavelengths that are not visible to the human eye, like the near-infrared (NIR) band. We can also deal with microwaves from synthetic aperture radar (SAR) or data quality channels that segment objects like clouds.

These new channels have the same structure as RGB channels: a grid of pixel values. As Figure 4 illustrates, this means we can use the same deep learning architectures. Only the first convolutional layer needs to be adjusted so that its kernels span all of the input channels.

Figure 4: the difference in the number of channels in RGB images and multispectral images (source: author).
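To see why only the first layer changes, here is a NumPy sketch of a single convolutional layer. It assumes a channels-first (in_channels, H, W) image, the usual (out_channels, in_channels, k, k) weight layout, and no padding, stride or bias; `conv2d` is a hypothetical helper, not part of the article's code. Each filter spans every input channel, so going from 3 to 7 channels changes only the filter depth, while the output shape, and therefore every later layer, stays the same.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d(x, weights):
    """Naive 'valid' 2D convolution (cross-correlation, as in deep learning libraries).
    x: (in_channels, H, W); weights: (out_channels, in_channels, k, k)."""
    if x.shape[0] != weights.shape[1]:
        raise ValueError("kernel depth must equal the number of input channels")
    k = weights.shape[-1]
    windows = sliding_window_view(x, (k, k), axis=(1, 2))  # (in_ch, H-k+1, W-k+1, k, k)
    return np.einsum("chwij,ocij->ohw", windows, weights)   # sum over channels and kernel

rng = np.random.default_rng(0)
rgb = rng.random((3, 32, 32))              # 3-channel image
multispectral = rng.random((7, 32, 32))    # 7-channel image, like the LICS input bands

w_rgb = rng.random((16, 3, 3, 3))          # 16 filters, each 3 channels deep
w_ms = rng.random((16, 7, 3, 3))           # 16 filters, each 7 channels deep
out_rgb = conv2d(rgb, w_rgb)
out_ms = conv2d(multispectral, w_ms)
print(out_rgb.shape, out_ms.shape)         # (16, 30, 30) (16, 30, 30)
```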

The increased complexity of the input data is why PCI is generally more useful for remote sensing problems. We will see this when we apply the method to a coastal image segmentation model, where it allows us to compare the inner workings of the model to more traditional research in this area [2].

Applying PCI with Python

—> go to notebook

Applying Permutation Channel Importance (PCI) to a Remote Sensing Model | Python Tutorial

To apply PCI, we will use the Landsat Irish Coastal Segmentation (LICS) Dataset (CC BY 4.0). It was created to support the development of deep learning methods for coastal water body segmentation [1]. Specifically, we will use the LICS test set, which contains 100 multispectral images. As seen in Figure 5, each image has a binary segmentation mask (target) that classifies every pixel as either land (0) or ocean (1). We also have 7 spectral bands available as input.

Figure 5: an example from the LICS test set. Shown are the segmentation mask (target) and the 7 spectral bands used as input to the model (source: author)

Models that have already been trained on LICS are available on Hugging Face. We will use LICS_FINETUNE_26JUL24.pth. This is the model with the highest accuracy on the LICS test set. It was trained using a two-step transfer learning approach (paper out soon). Let’s load the model and data to better understand how we can work with them.

Load data and model

We start with our imports. In the src directory, there is a folder called modelling. We add this folder to the Python path (line 16) and import the SegmentationDataset class from datasets.py (line 17). This is the same class used to train the model.

# Imports
import numpy as np
import matplotlib.pyplot as plt

import random
import glob
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from huggingface_hub import hf_hub_download

# Load python files
import sys
sys.path.append('../modelling')
from datasets import SegmentationDataset

We load the paths for the 100 images in the LICS test set (line 2). We then pass these paths into the SegmentationDataset class to make a dataset object (line 5). This allows us to load the data in the correct format for input into the model.

# Load paths
paths = glob.glob('../../data/LICS/test/*')

# Create dataset object
lics_dataset = SegmentationDataset(paths) 

We load our model from the Hugging Face repo (lines 2-7). We then move this model to the selected device (line 13), which is a GPU if one is available, and set it to evaluation mode (line 14). Now, let’s see how we can use this model to make predictions.

# Download the model directly from Hugging Face
model_path = hf_hub_download(
    repo_id="a-data-odyssey/coastal-image-segmentation", 
    filename="LICS_FINETUNE_26JUL24.pth")

# Load the model
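model = torch.load(model_path, map_location='cpu', weights_only=False)  # full pickled model

# Select a device (GPU if available) and set the model to evaluation mode
device = torch.device('mps' if torch.backends.mps.is_available()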
                      else 'cuda' if torch.cuda.is_available() 
                      else 'cpu')
model.to(device)
model.eval()  

We load the first instance of our dataset (line 2). This will include the 7 spectral bands and the target. We move the bands to the selected device (line 5), add a dimension for batch size (line 6) and get a prediction from our model (line 7).

# Load first instance
bands, target = lics_dataset[0]

# Make a prediction
image = bands.to(device)
image = image.unsqueeze(0)
output = model(image)

Before visualising, we must consider that the shape of the target is (2, 256, 256). This is because we have two classes: land and water. Similarly, the shape of the output is (1, 2, 256, 256). In this case, we have a 2D array of scores for each of the two classes, plus an additional dimension for batch size.

print(target.shape) #(2, 256, 256)
print(output.shape) #(1, 2, 256, 256)

We format the target and output using the argmax function. For both, this gives us an array of pixels with dimensions (256, 256). Each pixel will have a value of either 0 or 1. For example, when applied to the output (lines 5-6) a pixel will have a value of 0 if the land score was higher and 1 otherwise.

# Get the water mask 
target = np.argmax(target, axis=0)

# Get the predicted water mask
output = output.cpu().detach().numpy().squeeze()
output = np.argmax(output, axis=0)

Figure 6 shows the formatted target and output. We can see the model gives a good approximation of the coastline but lacks some granular details. The number in brackets gives the accuracy of the prediction. This tells us that 99.7% of the pixels have been correctly classified. Now let’s try to understand how this model is making this prediction.

Figure 6: comparison of the target and prediction. The number in brackets gives the accuracy which is the proportion of correctly classified pixels.
# Plot the prediction
fig, axs = plt.subplots(1, 2, figsize=(9, 5))
axs[0].imshow(target, cmap='gray')
axs[0].set_title('Target', fontsize=16)

axs[1].imshow(output, cmap='gray')

accuracy = np.mean(np.array(target) == output)  # proportion of correctly classified pixels
accuracy = round(accuracy, 3)
axs[1].set_title(f'Prediction ({accuracy})', fontsize=16)

for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])

Functions for PCI

We define a few helper functions to calculate PCI. The first calculates the performance metric. The PCI method can be applied to any problem, which is why we have also included the accuracy calculation for classification problems. This function can be adapted to include any relevant metric.

For segmentation, we also calculate accuracy. Keep in mind that each test instance will have its own accuracy value. This is why we take the average accuracy across all the instances (lines 8-14).

def performance_metric(targets, predictions, type='segmentation'):
    """Calculate the performance metric for the model"""

    targets = np.array(targets)
    predictions = np.array(predictions)

    if type == 'segmentation':
        # Calculate average accuracy
        accuracy_list = []
        for t,p in zip(targets, predictions):
            accuracy = np.mean(t == p)
            accuracy_list.append(accuracy)

        metric = np.mean(accuracy_list)
    elif type == 'classification':
        # Calculate accuracy
        metric = np.mean(targets == predictions)
    else:
        raise ValueError(f"Unknown type: {type}")

    return metric

The next function will permute a given channel in an image. To do this, we get the channel (line 12) and flatten it (line 13). We then shuffle the array (line 16), reshape it to the original dimensions (line 17) and replace the original channel (line 18). We want to use this image as input into the model, so the last step is to format it as a tensor (line 21).

def shuffle_channel(img, channel):
    """Shuffle a given channel of an image
        img: tensor, image to shuffle with shape (batch, channels, height, width)
        channel: int, channel to shuffle (only the first image in the batch is permuted)"""

    # Get size of the image
    size_x = img.shape[2]
    size_y = img.shape[3]

    # Flatten the channel (copy first so the input tensor is not modified in place)
    perm_img = img.detach().cpu().numpy().copy()
    perm_channel = perm_img[0][channel]
    perm_channel = perm_channel.ravel()

    # Shuffle the channel
    perm_channel = np.random.permutation(perm_channel)
    perm_channel = perm_channel.reshape(size_x, size_y)
    perm_img[0][channel] = perm_channel

    # Convert to tensor
    perm_img = torch.tensor(perm_img)

    return perm_img

The last function will calculate the performance metric for the entire dataset after a given channel has been permuted. You can see we format the data and get predictions in a similar way as before, except now we are looping over the entire dataset using a PyTorch DataLoader (line 4). Now let’s see how we can use these functions to plot PCI values.

def get_permuted_performance(model, dataset, channel=-1, type='segmentation'):
    """Calculate the performance metric for the model with permuted data"""

    data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    model.eval()

    targets = []
    outputs = []

    for image, target in data_loader:

        # Format target
        target = target.numpy().squeeze()
        target = np.argmax(target, axis=0)

        # Permuted image
        if channel != -1:
            image = shuffle_channel(image, channel)

        # Get prediction (gradients are not needed for evaluation)
        image = image.to(device)
        with torch.no_grad():
            output = model(image)

        output = output.cpu().detach().numpy().squeeze()
        output = np.argmax(output, axis=0)

        # Append to list
        targets.append(target)
        outputs.append(output)
    
    metric = performance_metric(targets, outputs, type=type)

    return metric

Calculating PCI

You may have noticed that, in the above function, we have the option to calculate the performance on data that is not permuted. This is done by setting channel=-1. We use this to calculate a baseline performance score (line 2). This gives us an average accuracy of 98.5% over the 100 images in the LICS test set.

# Get baseline performance 
baseline = get_permuted_performance(model, lics_dataset, channel=-1, type='segmentation')
print(f'Baseline accuracy: {np.round(baseline,3)}')

Now, let’s use the function to calculate PCI when we do permute a channel. In this case, we select channel 3, which is the NIR band. Now, we get an average accuracy of 60.5%. Subtracting this from our baseline gives a PCI score of 0.38, a drop of 38 percentage points in accuracy.

# Get nir performance
nir = get_permuted_performance(model, lics_dataset, channel=3, type='segmentation')
print(f'NIR accuracy: {np.round(nir,3)}')
print(f'PCI: {np.round(baseline - nir,3)}')

Next, we repeat this process for all channels (line 4). We also make sure to repeat the calculation 3 times for each channel (line 7) and then take the average of those values (line 12).

pci_scores = []

# Repeat for all channels
for channel in tqdm(range(7)):
    # Repeat the calculation 3 times
    metrics = []
    for r in range(3):
        metric = get_permuted_performance(model, lics_dataset, channel=channel, type='segmentation')
        metrics.append(metric)

    # Calculate PCI    
    pci = baseline - np.mean(metrics)
    pci_scores.append(pci)

Finally, we can plot our PCI scores. You can see the output in Figure 7. This suggests that NIR is the most important spectral band for this model. This is followed by SWIR 2 and SWIR 1.

band_names = {0: 'Blue', 1: 'Green', 2: 'Red', 3: 'NIR', 4: 'SWIR1', 5: 'SWIR2', 6: 'Thermal'}

# Plot the PCI scores
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.bar(list(band_names.values()), pci_scores)

ax.set_ylabel('PCI', fontsize=16, fontweight='bold')
ax.set_xlabel('Spectral Band', fontsize=16, fontweight='bold')
Figure 7: permutation channel importance scores for a coastal image segmentation model (source: author)

This is an interesting result. These are all infrared bands, and they are commonly used in traditional, deterministic approaches for water body segmentation [2]. In fact, the NIR band is particularly important for this task [3]. In other words, the model uses spectral bands that are consistent with previous research.

This is one of the benefits of this type of research. Through explaining the model, industry professionals are more likely to trust it. This is especially true if the way it works is consistent with their domain knowledge. From a machine learning perspective, these results are also useful. Based on them, we may be able to reduce the number of channels used to train future models.
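If the goal is to reduce the number of input channels for future models, the `pci_scores` list and `band_names` dictionary from the code above could be turned into a shortlist. This is a minimal sketch: `rank_channels` is a hypothetical helper, and the threshold is an assumption you would need to choose and ideally validate by retraining with the reduced input.

```python
import numpy as np

def rank_channels(pci_scores, band_names, threshold=0.0):
    """Return band names ordered from highest to lowest PCI,
    keeping only channels whose PCI is above `threshold`."""
    scores = np.asarray(pci_scores, dtype=float)
    order = np.argsort(scores)[::-1]                  # indices, most important first
    return [band_names[int(i)] for i in order if scores[i] > threshold]

# Illustrative, made-up scores for four hypothetical channels
names = {0: 'ch0', 1: 'ch1', 2: 'ch2', 3: 'ch3'}
shortlist = rank_channels([0.02, 0.30, -0.01, 0.10], names, threshold=0.05)
```

With the article's variables, the call would be `rank_channels(pci_scores, band_names, threshold=...)`.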


I hope you enjoyed this article! See the course page for more XAI courses. You can also find me on Threads | YouTube | Medium 


Additional Resources

Permutation Feature Importance from Scratch

Interpreting a Semantic Segmentation Model for Coastline Detection

Python Notebook

Datasets

Conor O’Sullivan, & Soumyabrata Dev. (2024). The Landsat Irish Coastal Segmentation (LICS) Dataset. (CC BY 4.0) https://doi.org/10.5281/zenodo.13742222

Conor O’Sullivan (2024). Pot Plant Dataset. (CC BY 4.0) https://www.kaggle.com/datasets/conorsully1/pot-plants

References

[1] O’Sullivan, C., Kashyap, A., Coveney, S., Monteys, X. and Dev, S., 2024. Enhancing coastal water body segmentation with Landsat Irish Coastal Segmentation (LICS) dataset. Remote Sensing Applications: Society and Environment, p.101276. https://arxiv.org/abs/2409.15311

[2] O’Sullivan, C., Coveney, S., Monteys, X. and Dev, S., 2023, July. Interpreting a semantic segmentation model for coastline detection. In 2023 Photonics & Electromagnetics Research Symposium (PIERS) (pp. 209-215). IEEE. https://arxiv.org/abs/2405.11500

[3] Mondejar, J.P. and Tongco, A.F., 2019. Near infrared band of Landsat 8 as water index: a case study around Cordova and Lapu-Lapu City, Cebu, Philippines. Sustainable Environment Research, 29, pp.1–15.

