How to Analyze Classification Model Predictions: 2-Class vs. 3-Class

Analyzing the predictions of your classification models is crucial for understanding their strengths and weaknesses, whether you’re working with a binary (2-class) or a multi-class (3-class or more) problem. This guide walks through code for visualizing and interpreting your model’s performance, so you can compare outcomes across the two scenarios.

Let’s start by assuming you have log predictions (log_preds, i.e. log probabilities) and true labels (y) from your validation set, perhaps obtained using a library like fastai:

log_preds = learn.predict()
y = data.val_y

or using Test Time Augmentation for potentially more robust predictions:

log_preds, y = learn.TTA()

Now, to delve deeper into the analysis, consider the following Python functions built upon libraries like NumPy and Matplotlib, designed to dissect your classification results.

import numpy as np
import matplotlib.pyplot as plt
import PIL.Image # imported explicitly so PIL.Image.open is available below

def plot_val_with_title(idxs, title, y, data, probs):
    "Plots images with titles from validation set based on given indexes."
    imgs = np.stack([data.val_ds[x][0] for x in idxs])
    if type(y) == int: # Single class: use the same class index for every image
        title_probs = [f"{data.classes[y]}: {probs[x,y]:.4f}" for x in idxs]
    else: # Per-image class indices (e.g. for the uncertain-prediction plot)
        title_probs = [f"{data.classes[y[i]]}: {probs[x,y[i]]:.4f}" for i, x in enumerate(idxs)]
    print(title)
    return plots(data.val_ds.denorm(imgs), rows=1, titles=title_probs)

def plots(ims, figsize=(12,6), rows=1, titles=None):
    "Helper function to plot images in a grid."
    f = plt.figure(figsize=figsize)
    for i in range(len(ims)):
        sp = f.add_subplot(rows, len(ims)//rows, i+1)
        sp.axis('Off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=10) # Smaller font so titles fit above each image
        plt.imshow(ims[i])
    plt.show() # Ensure plot is displayed

def load_img_id(ds, idx, PATH):
    "Loads an image by index from a dataset."
    return np.array(PIL.Image.open(PATH+ds.fnames[idx]))

def most_by_mask(mask, y, mult, probs, preds, data):
    "Returns indexes sorted by probability based on a boolean mask."
    idxs = np.where(mask)[0]
    return idxs[np.argsort(mult * probs[idxs,y])[:4]] # Top 4 after sorting by probability of class y (direction set by mult)

def most_by_correct(y, is_correct, probs, preds, data):
    "Finds most correct or incorrect predictions for a given class."
    mult = -1 if is_correct else 1 # Correct predictions: descending order, Incorrect: ascending
    # Note: & binds tighter than ==, so both comparisons must be parenthesized
    return most_by_mask(((preds == data.val_y) == is_correct) & (data.val_y == y), y, mult, probs, preds, data)

To utilize these functions, you first need to process your log_preds to get class predictions (preds) and probabilities (probs). The number of classes (num_classes) is derived from your dataset.

num_classes = len(data.classes)
preds = np.argmax(log_preds, axis=1) # Get class predictions
probs = np.exp(log_preds) # Convert log probabilities to probabilities
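
Before digging into individual images, it’s worth a quick sanity check that these arrays look the way you expect. The snippet below is a minimal sketch that assumes log_preds is a 2-D array with one row per validation image and one column per class (some TTA implementations return one set of predictions per augmentation instead, in which case you’d average over that axis first) and that y is a NumPy array of integer labels.

# Sanity check: expected shapes and a baseline accuracy number
assert probs.shape == (len(y), num_classes), "unexpected shape for probs"
assert np.allclose(probs.sum(axis=1), 1.0, atol=1e-3), "each row should sum to ~1"
accuracy = (preds == y).mean() # fraction of validation images predicted correctly
print(f"Overall accuracy: {accuracy:.4f} on {len(y)} validation images")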

Analyzing Uncertainty in Predictions

One way to compare model behavior, especially between different numbers of classes, is to examine prediction uncertainty. In a 2-class scenario, high uncertainty might mean the model is struggling to distinguish between the two classes. In a 3-class (or multi-class) scenario, uncertainty could arise from confusion between multiple classes.

The following code identifies and visualizes the most uncertain predictions:

most_uncertain = np.argsort(np.average(np.abs(probs - (1/num_classes)), axis=1))[:4] # Indices of the 4 most uncertain predictions
idxs_col = np.argsort(np.abs(probs[most_uncertain,:] - (1/num_classes)))[:4,-1] # For each, the class whose probability deviates most from uniform
plot_val_with_title(most_uncertain, "Most uncertain predictions", idxs_col, data, probs)

This code measures uncertainty by how close the predicted probabilities are to a uniform distribution (1/num_classes per class): the smaller the average deviation, the less decisive the model. Visualizing these cases helps identify patterns where the model is indecisive.
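
To see how the measure behaves, here is a tiny worked example on a hand-made probability array (the numbers are purely illustrative): a confident prediction sits far from the uniform value of 1/3, while an indecisive one sits almost exactly on top of it, so sorting the scores in ascending order surfaces the indecisive prediction first.

# Toy illustration of the uncertainty score with 3 classes (made-up numbers)
toy_probs = np.array([[0.98, 0.01, 0.01],  # confident prediction
                      [0.34, 0.33, 0.33]]) # indecisive prediction
uniform = 1 / toy_probs.shape[1]
scores = np.average(np.abs(toy_probs - uniform), axis=1)
print(scores)             # ~[0.431, 0.004]: the confident row scores higher
print(np.argsort(scores)) # [1 0]: the indecisive row ranks as most uncertain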

Examining Correct and Incorrect Predictions by Class

To understand class-specific performance, particularly when comparing 2-class vs. 3-class models, it’s vital to see which classes are predicted correctly and incorrectly most often.

For example, to see the most correct predictions for class 0:

label = 0 # Class index
plot_val_with_title(most_by_correct(label, True, probs, preds, data), f"Most correct class {data.classes[label]}", label, data, probs)

And to see the most incorrect predictions for class 2:

label = 2 # Class index
plot_val_with_title(most_by_correct(label, False, probs, preds, data), f"Most incorrect class {data.classes[label]}", label, data, probs)
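
If you’d rather sweep every class than pick labels by hand, a short loop over num_classes works the same way for the 2-class and 3-class cases (this uses only the functions and variables defined above):

# Most correct and most incorrect examples for every class
for label in range(num_classes):
    plot_val_with_title(most_by_correct(label, True, probs, preds, data),
                        f"Most correct class {data.classes[label]}", label, data, probs)
    plot_val_with_title(most_by_correct(label, False, probs, preds, data),
                        f"Most incorrect class {data.classes[label]}", label, data, probs)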

By visualizing these correct and incorrect examples for each class in both 2-class and 3-class models, you can gain insights into:

  • Class separability: Are certain classes inherently harder to distinguish? This might be more pronounced in multi-class problems.
  • Confusion patterns: In 3-class problems, is the model confusing class A with B, or A with C, or both? This level of detail is not present in binary classification.
  • Data imbalances: Are incorrect predictions concentrated in minority classes? This is a common issue in both binary and multi-class settings, but it can be harder to diagnose with more classes; the per-class breakdown sketched below makes such concentrations easy to spot.
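
A quick numerical companion to these visual checks is a per-class count of correct and incorrect predictions. The sketch below uses only NumPy and the preds, data, and num_classes values defined earlier; it works identically for 2 or 3 classes.

# Per-class breakdown: how many images of each true class were predicted correctly
val_y = np.asarray(data.val_y)
for c in range(num_classes):
    class_mask = (val_y == c)
    n_total = class_mask.sum()
    n_correct = (preds[class_mask] == c).sum()
    print(f"{data.classes[c]:>12}: {n_correct}/{n_total} correct ({n_correct / max(n_total, 1):.1%})")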

Comparison Considerations:

When comparing 2-class and 3-class classification, remember that evaluation metrics and interpretation can differ:

  • Baseline Accuracy: In a balanced 2-class problem, random guessing yields 50% accuracy. In a balanced 3-class problem, it’s ~33%. Direct accuracy comparison needs to consider this.
  • Confusion Matrix: Essential for multi-class to see where mistakes are happening. In 2-class, it’s simpler (True/False Positives/Negatives).
  • Precision/Recall/F1-score: For 2-class, you often focus on the positive class. For 3-class, you typically need macro or micro averaging to get a single score across all classes, especially when classes are imbalanced; the sketch below shows both.
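
If you have scikit-learn available (an assumption; it isn’t used anywhere else in this post), the confusion matrix and the averaged metrics from the last two bullets take only a few lines, and the same code works unchanged whether num_classes is 2 or 3:

# Confusion matrix and averaged metrics (requires scikit-learn)
from sklearn.metrics import confusion_matrix, f1_score, classification_report

val_y = np.asarray(data.val_y)
print(confusion_matrix(val_y, preds)) # rows = true class, columns = predicted class
print(f"Macro F1: {f1_score(val_y, preds, average='macro'):.4f}") # unweighted mean over classes
print(f"Micro F1: {f1_score(val_y, preds, average='micro'):.4f}") # aggregated over all samples
print(classification_report(val_y, preds, target_names=list(data.classes)))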

By using the provided code and visualization techniques, you can move beyond simple accuracy scores and develop a deeper understanding of how your classification models perform, and how their performance characteristics differ when moving from binary to multi-class scenarios. This detailed analysis is key to iteratively improving your models and ensuring they meet the specific demands of your classification task.
