
Confusion Matrix Tricks in Python

WBOY (original), 2023-06-11 10:43:54

With the growing popularity of machine learning and data mining, more and more data scientists and researchers use Python to process and analyze data. Its readability and ease of use have made it widely adopted in deep learning and artificial intelligence. Beginners still run into difficulties, however, and the confusion matrix is a common one. This article explains what a confusion matrix is, how to compute one in Python, and some useful techniques for working with it.

1. What is a confusion matrix

In machine learning and data mining, a confusion matrix is a square table that compares a classifier's predictions with the actual labels. In the convention used by scikit-learn, each row corresponds to a true class and each column to a predicted class. Metrics such as accuracy, error rate, precision and recall can all be derived from it. Because it shows which classes are confused with which, it is a key reference when improving and tuning a classifier.

For a binary classifier, the confusion matrix consists of four counts:

  • True Positive (TP): The classification algorithm correctly predicts the positive class as a positive class.
  • False Negative (FN): The classification algorithm incorrectly predicts a positive class as a negative class.
  • False Positive (FP): The classification algorithm incorrectly predicts a negative class as a positive class.
  • True Negative (TN): The classification algorithm correctly predicts a negative class as a negative class.
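
The four counts above can be tallied directly from paired label lists. The following is a minimal pure-Python sketch, not a library routine: the helper name binary_counts and the sample labels are made up for illustration, and the positive class is assumed to be encoded as 1.

```python
def binary_counts(y_true, y_pred, positive=1):
    """Return (tp, fn, fp, tn) for a binary problem (hypothetical helper)."""
    tp = fn = fp = tn = 0
    for truth, pred in zip(y_true, y_pred):
        if truth == positive and pred == positive:
            tp += 1   # positive predicted as positive
        elif truth == positive:
            fn += 1   # positive predicted as negative
        elif pred == positive:
            fp += 1   # negative predicted as positive
        else:
            tn += 1   # negative predicted as negative
    return tp, fn, fp, tn


# Made-up labels for illustration only
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1, 0, 1]

tp, fn, fp, tn = binary_counts(y_true, y_pred)
accuracy = (tp + tn) / (tp + fn + fp + tn)
precision = tp / (tp + fp)   # of the predicted positives, how many are real
recall = tp / (tp + fn)      # of the real positives, how many were found

print(tp, fn, fp, tn)
print(accuracy, precision, recall)
```

Precision and recall are undefined when their denominators are zero, which this sketch does not guard against.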

2. How to calculate the confusion matrix

The scikit-learn library provides a convenient function for this, confusion_matrix(). It takes the true labels of the test set and the labels predicted by the classifier, and returns the confusion matrix as a NumPy array. Its basic syntax is as follows:

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

Here, y_true holds the true (ground-truth) labels, y_pred holds the labels predicted by the classifier, labels is an optional list of class labels that selects and orders the rows and columns (if omitted, the labels that appear in y_true or y_pred are used in sorted order), and sample_weight is an optional weight for each sample (omit it if weighting is not needed).
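
As a rough illustration of the optional parameters, the sketch below reorders the classes with labels and weights one sample with sample_weight. The weights are arbitrary values chosen for the example. It also uses the normalize argument, which is not part of the signature shown above and is assumed to be available (scikit-learn 0.22 or newer).

```python
from sklearn.metrics import confusion_matrix

y_true = [1, 0, 1, 2, 0, 1]
y_pred = [1, 0, 2, 1, 0, 2]

# labels fixes the row/column order (here reversed: 2, 1, 0)
cm_reordered = confusion_matrix(y_true, y_pred, labels=[2, 1, 0])
print(cm_reordered)

# sample_weight: the last sample counts three times (illustrative weights)
weights = [1, 1, 1, 1, 1, 3]
cm_weighted = confusion_matrix(y_true, y_pred, sample_weight=weights)
print(cm_weighted)

# normalize="true" divides each row by the number of true samples in that row
# (assumes scikit-learn >= 0.22)
cm_rates = confusion_matrix(y_true, y_pred, normalize="true")
print(cm_rates)
```

Passing labels explicitly is also a way to keep the matrix shape stable when some class happens to be missing from a particular test split.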

For example, suppose we need to calculate the confusion matrix of the following data:

y_true = [1, 0, 1, 2, 0, 1]
y_pred = [1, 0, 2, 1, 0, 2]

To calculate the confusion matrix, you can use the following code:

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_true, y_pred)
print(cm)

The output result is:

[[2 0 0]
 [0 1 2]
 [0 1 0]]

Rows correspond to the true classes 0, 1 and 2, and columns to the predicted classes in the same order. The matrix therefore shows that "0" is correctly classified as "0" twice, "1" is correctly classified as "1" once, and "2" is never correctly classified as "2". "1" is misclassified as "2" twice, and "2" is misclassified as "1" once. No other errors occur.
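
The same reading can be done numerically. The sketch below hard-codes the matrix from this example and derives per-class recall and precision from its rows and columns. It assumes every class has at least one true sample and at least one prediction, because otherwise the divisions would hit zero.

```python
import numpy as np

# Confusion matrix from the example (rows: true 0, 1, 2; columns: predicted 0, 1, 2)
cm = np.array([[2, 0, 0],
               [0, 1, 2],
               [0, 1, 0]])

correct = np.diag(cm)                   # correctly classified samples per class
recall = correct / cm.sum(axis=1)       # divide by row totals (true counts)
precision = correct / cm.sum(axis=0)    # divide by column totals (predicted counts)
accuracy = correct.sum() / cm.sum()     # overall fraction of correct predictions

print(recall)
print(precision)
print(accuracy)
```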

3. Display the confusion matrix

There are many situations where we need a better visualization of the confusion matrix. The matplotlib library in Python can visualize confusion matrices. The following is Python code that uses the matplotlib library and sklearn.metrics to visualize the confusion matrix.

import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Adjust the layout after the axis labels exist so they are not clipped
    plt.tight_layout()

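# Example data (illustrative; substitute your own test labels and predictions)
y_test = [1, 0, 1, 2, 0, 1]
y_pred = [1, 0, 2, 1, 0, 2]
class_names = ['0', '1', '2']

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)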
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

In the code above, we define a custom function named plot_confusion_matrix(). It takes the confusion matrix and a list of class names, along with optional normalize, title and cmap arguments, and draws the matrix as a color image in which the color of each cell reflects its value. We then compute the confusion matrix from the true and predicted labels and pass it to plot_confusion_matrix(), once with raw counts and once normalized by row.

4. Summary

Python offers a large number of visualization and data analysis libraries that help data scientists and researchers evaluate machine learning models quickly. In this article, we introduced the confusion matrix, showed how to compute it with scikit-learn's confusion_matrix() function, and showed how to plot it with matplotlib. Because the confusion matrix reveals not only how often a classifier is wrong but also which classes it confuses, it is well worth learning for anyone who builds classification models.

