首頁  >  文章  >  後端開發  >  Python中的混淆矩陣技巧

Python中的混淆矩陣技巧

WBOY
WBOY原創
2023-06-11 10:43:542459瀏覽

隨著機器學習和資料探勘的流行,越來越多的資料科學家和研究人員開始使用Python這種高階程式語言來處理和分析數據,並且Python的直覺性和易用性使其在深度學習和人工智慧的領域中已廣泛應用。然而,許多初學者在使用Python時遇到了一些困難,其中之一就是混淆矩陣的難題。在本文中,我們將介紹Python中混淆矩陣的使用方法以及一些處理混淆矩陣時有用的技巧。

一、什麼是混淆矩陣

在深度學習和資料探勘中,混淆矩陣是一種矩形表格,用於比較預測結果和實際結果之間的差異。此矩陣顯示了分類演算法的效能,包括分類演算法的準確性、錯誤率、精確度和召回率等重要指標。混淆矩陣通常使分類器的性能可視化,並為分類器的改進和優化提供預測結果的主要參考。

通常情況下,混淆矩陣由四個參數組成:

  • 真陽性(TP):分類演算法正確地將正類預測為正類。
  • 假陰性(FN):分類演算法錯誤地將正類預測為負類。
  • 假陽性(FP):分類演算法錯誤地將負類預測為正類。
  • 真陰性(TN):分類演算法正確地將負類預測為負類。

二、如何計算混淆矩陣

Python中的scikit-learn函式庫提供了一個方便的函數來計算混淆矩陣。此函數稱為confusion_matrix(),可以作為分類器和測試集的真實結果之間的輸入,並傳回混淆矩陣的參數值。此函數地語法如下:

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

其中,y_true表示分類器的正確結果,y_pred表示分類器的預測結果,labels表示類別標籤的名稱(如果不提供,則預設為從y_true和y_pred中提取的值),sample_weight表示每個樣本的權重(如果不需要,則不用設定該參數)。

例如,假設我們需要計算以下資料的混淆矩陣:

y_true = [1, 0, 1, 2, 0, 1]
y_pred = [1, 0, 2, 1, 0, 2]

為了計算混淆矩陣,可以使用以下程式碼:

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_true, y_pred)
print(cm)

輸出結果為:

array([[2, 0, 0],
       [0, 1, 2],
       [0, 1, 0]])

即此混淆矩陣顯示出「1」被正確分類為「1」的情況有2次,「0」被正確分類為「0」的情況有1次,「2」被正確分類為「2 」的情況有0次,「1」被錯誤分類為「2」的情況有2次,「2」被錯誤分類為「1」的情況有1次,「0」被錯誤分類為「2」的情況有1次。

三、展示混淆矩陣

有許多情況下,我們需要更好的視覺化混淆矩陣。 Python中的matplotlib函式庫可以使混淆矩陣視覺化。下面是的Python程式碼,它使用了matplotlib函式庫和sklearn.metrics來實現混淆矩陣的可視化。

import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

以上程式碼中,我們定義了一個名為plot_confusion_matrix()的自訂函數,該函數作為參數接受混淆矩陣的參數,類別名稱的文字字串,將混淆矩陣作為彩色圖像輸出,其中混淆矩陣的每個單元格的顏色表示其值的大小。接下來,我們需要使用各自的真實類別和預測類別來計算混淆矩陣,並使用在上面定義的plot_confusion_matrix()函數來將混淆矩陣表示出來。

四、小結

Python語言提供了大量的視覺化和資料分析函式庫,可以讓資料科學家和研究人員更快速地進行深度學習和人工智慧的資料分析。在本文中,我們介紹了混淆矩陣及其應用,以及Python中如何計算混淆矩陣和如何使用matplotlib函式庫來產生混淆矩陣的圖形。混淆矩陣技術在深度學習和人工智慧領域中有著重要的應用,因此,學習混淆矩陣技術是非常必要的。

以上是Python中的混淆矩陣技巧的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn