
Plotting and visualizing a confusion matrix for a PyTorch classification model

Step 1. Obtain Confusion Matrix
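The code below assumes that Emotion_kinds (the number of classes), a trained model, a test_loader, a device, and test_num already exist. As a minimal stand-in sketch (an assumption for illustration, not the article's real model or dataset), a toy CNN and random tensors can play those roles so the remaining snippets run end to end:

# Stand-in setup (assumed context): a toy CNN and random data replace the real emotion model and dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

Emotion_kinds = 8                                   # number of emotion classes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Any classifier mapping [N, 3, 200, 200] -> [N, Emotion_kinds] will do here
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, Emotion_kinds),
).to(device)
model.eval()                                        # evaluation mode for testing

# Random test set: labels are shaped [N, 1], matching the article's loader
test_imgs = torch.randn(200, 3, 200, 200)
test_targets = torch.randint(0, Emotion_kinds, (200, 1))
test_loader = DataLoader(TensorDataset(test_imgs, test_targets), batch_size=50)
test_num = len(test_loader.dataset)                 # total number of test samples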

# First define an empty confusion matrix of size (number of classes) x (number of classes)
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds).to(device)
# Use torch.no_grad() to significantly reduce the GPU memory footprint during testing
with torch.no_grad():
    for step, (imgs, targets) in enumerate(test_loader):
        # imgs:    torch.Size([50, 3, 200, 200])
        # targets: torch.Size([50, 1]) -- one extra dimension, so squeeze it away
        targets = targets.squeeze()  # [50, 1] -----> [50]

        # Move the tensors to the GPU
        targets = targets.to(device)
        imgs = imgs.to(device)
        # print(step, imgs.shape, targets.shape)

        out = model(imgs)
        # Accumulate this batch's predictions into the confusion matrix
        conf_matrix = confusion_matrix(out, targets, conf_matrix)
# Move the finished confusion matrix back to the CPU
conf_matrix = conf_matrix.cpu()

The confusion matrix is obtained using the confusion_matrix function, which is defined as follows:

def confusion_matrix(preds, labels, conf_matrix):
    # Turn the network outputs (logits) into predicted class indices
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1  # rows are predictions, columns are true labels
    return conf_matrix
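
Note the indexing convention: conf_matrix[p, t] puts the predicted class on the rows and the true label on the columns. A tiny hand-checked example (purely illustrative, not data from the article) shows how the update works:

# Toy check of the update rule: 3 classes, 4 samples
logits = torch.tensor([[2.0, 0.1, 0.1],   # argmax -> class 0
                       [0.1, 3.0, 0.2],   # argmax -> class 1
                       [0.0, 0.5, 2.5],   # argmax -> class 2
                       [1.5, 0.2, 0.1]])  # argmax -> class 0
labels = torch.tensor([0, 1, 1, 2])
cm = confusion_matrix(logits, labels, torch.zeros(3, 3))
print(cm)
# tensor([[1., 0., 1.],
#         [0., 1., 0.],
#         [0., 1., 0.]])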

Once the loop over test_loader finishes, we have the confusion matrix for the whole test set. Next we count the number of correct identifications per class and prepare to visualize the matrix:

import numpy as np

conf_matrix = np.array(conf_matrix.cpu())  # move the confusion matrix from GPU to CPU and convert it to a NumPy array
corrects = conf_matrix.diagonal(offset=0)  # the diagonal holds the number of correct identifications per class
per_kinds = conf_matrix.sum(axis=0)        # total number of test samples per class (the columns hold the true labels)

print("Sum of all entries in the confusion matrix: {0}, total number of test samples: {1}".format(int(np.sum(conf_matrix)), test_num))
print(conf_matrix)

# Recognition accuracy for each emotion
print("Total number of each emotion:", per_kinds)
print("Number of correct predictions for each emotion:", corrects)
print("The recognition accuracy for each emotion is: {0}".format([rate * 100 for rate in corrects / per_kinds]))

Executing this step prints the full confusion matrix together with the per-emotion counts and accuracies (the original screenshot of this output is omitted here).

Step 2. Confusion Matrix Visualization

Now visualize the confusion matrix obtained above:

# Drawing the confusion matrix
import matplotlib.pyplot as plt

Emotion_kinds = 8  # the number of categories; change this to match your own task
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']  # label for each category

# Display the matrix as an image
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# Write the count into each cell of the plot
thresh = conf_matrix.max() / 2  # color threshold: cells above it get white text, the rest black
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # Note that it is matrix[y, x], not matrix[x, y]: row y is drawn at height y
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")

plt.tight_layout()  # keep the labels from overlapping the figure
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels, rotation=45)  # tilt the x-axis labels by 45 degrees
plt.show()
plt.close()
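
Because this conf_matrix stores predictions on the rows and true labels on the columns, labeling the axes makes the plot easier to read. As an optional extension (not in the original snippet), axis labels and a savefig call can be added just before plt.show(); the file name here is only an example:

plt.ylabel('Predicted label')   # rows hold the model's predictions
plt.xlabel('True label')        # columns hold the ground-truth classes
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")  # optionally save the figure to disk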

The final visualization of the confusion matrix (figure omitted here) is a colored grid with each cell annotated with its count.

Obtaining Other Classification Metrics

For example, the F1 score, TP, TN, FP, FN, precision, and recall. These are left to be added later (they are not used yet), but they can all be derived directly from the confusion matrix, as sketched below.
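
A small helper sketch (not part of the original article; it assumes the rows-are-predictions convention used above) shows how these metrics can be read off the confusion matrix:

import numpy as np

def per_class_metrics(cm):
    # Per-class TP/FP/FN/TN, precision, recall and F1.
    # Assumes cm[p, t] counts samples predicted as class p whose true label is t
    # (rows = predictions, columns = ground truth), as built above.
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    metrics = []
    for k in range(cm.shape[0]):
        tp = cm[k, k]
        fp = cm[k, :].sum() - tp   # predicted as k, but actually another class
        fn = cm[:, k].sum() - tp   # actually k, but predicted as another class
        tn = total - tp - fp - fn
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        metrics.append({"TP": tp, "FP": fp, "FN": fn, "TN": tn,
                        "precision": precision, "recall": recall, "F1": f1})
    return metrics

for k, m in enumerate(per_class_metrics(conf_matrix)):
    print(labels[k], m)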

Summary

This concludes the article on plotting and visualizing a confusion matrix for a PyTorch classification model. For more on drawing confusion matrices with PyTorch, please search my earlier articles or continue browsing the related articles below, and I hope you will keep supporting me in the future!