>  기사  >  백엔드 개발  >  Python의 혼동 행렬 트릭

Python의 혼동 행렬 트릭

WBOY
WBOY원래의
2023-06-11 10:43:542523검색

머신러닝과 데이터마이닝의 인기로 데이터 처리 및 분석을 위해 고급 프로그래밍 언어인 Python을 사용하는 데이터 과학자와 연구자가 늘어나고 있으며, Python의 직관성과 사용 용이성으로 딥러닝 분야에서 인기를 끌고 있습니다. 인공지능 분야에서 널리 활용되고 있습니다. 그러나 많은 초보자들이 Python을 사용할 때 몇 가지 어려움을 겪는데, 그 중 하나는 혼동 행렬의 어려움입니다. 이 기사에서는 Python에서 혼동 행렬을 사용하는 방법과 혼동 행렬을 다룰 때 유용한 몇 가지 기술을 소개합니다.

1. 혼동행렬이란 무엇인가요

딥러닝과 데이터 마이닝에서 혼동행렬은 예측된 결과와 실제 결과의 차이를 비교하는 데 사용되는 직사각형 테이블입니다. 이 매트릭스는 분류 알고리즘의 정확도, 오류율, 정밀도 및 재현율과 같은 중요한 지표를 포함하여 분류 알고리즘의 성능을 보여줍니다. 혼동행렬은 일반적으로 분류기의 성능을 시각화하고 분류기의 개선 및 최적화를 위한 예측 결과에 대한 주요 참조를 제공합니다.

일반적으로 혼동 행렬은 다음 네 가지 매개변수로 구성됩니다.

  • 진정성(TP): 분류 알고리즘은 양성 클래스를 양성 클래스로 정확하게 예측합니다.
  • False Negative(FN): 분류 알고리즘이 양성 클래스를 음성 클래스로 잘못 예측합니다.
  • False Positive(FP): 분류 알고리즘이 음성 클래스를 양성 클래스로 잘못 예측합니다.
  • 트루 네거티브(TN): 분류 알고리즘은 네거티브 클래스를 네거티브 클래스로 정확하게 예측합니다.

2. 혼동행렬 계산 방법

파이썬의 scikit-learn 라이브러리는 혼동행렬을 계산하는 편리한 기능을 제공합니다. Confusion_matrix()라고 불리는 이 함수는 분류기와 테스트 세트의 실제 결과 사이의 입력으로 사용될 수 있으며, 혼동행렬의 매개변수 값을 반환합니다. 이 함수의 구문은 다음과 같습니다.

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

여기서, y_true는 분류기의 올바른 결과를 나타내고, y_pred는 분류기의 예측 결과를 나타내며, labels는 클래스 레이블의 이름을 나타냅니다(제공되지 않은 경우 기본값은 y_true 및 y_pred에서 추출된 값), Sample_weight는 각 샘플의 가중치를 나타냅니다(필요하지 않은 경우 이 매개변수를 설정하지 않음).

예를 들어, 다음 데이터의 혼동 행렬을 계산해야 한다고 가정합니다.

y_true = [1, 0, 1, 2, 0, 1]
y_pred = [1, 0, 2, 1, 0, 2]

혼동 행렬을 계산하려면 다음 코드를 사용할 수 있습니다.

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_true, y_pred)
print(cm)

출력 결과는 다음과 같습니다.

array([[2, 0, 0],
       [0, 1, 2],
       [0, 1, 0]])

즉, 혼동 행렬은 "1"이 "1"로 올바르게 분류된 경우가 2개, "0"이 "0"으로 올바르게 분류된 경우가 1개, "2"가 "2"로 올바르게 분류된 경우가 0개, 0개의 경우가 있음을 보여줍니다. "1"이 "2"로 잘못 분류된 경우 "는 "1"로 두 번 오분류되었고, "2"는 "1"로 한 번 오분류되었으며, "0"은 "2"로 한 번 오분류되었습니다.

3. 혼동 행렬 표시

정확한 혼동 행렬 시각화가 필요한 상황이 많이 있습니다. Python의 matplotlib 라이브러리는 혼동 행렬을 시각화할 수 있습니다. 다음은 matplotlib 라이브러리와 sklearn.metrics를 사용하여 혼동 행렬을 시각화하는 Python 코드입니다.

import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

위 코드에서는 혼동 행렬의 매개변수, 카테고리 이름의 텍스트 문자열을 매개변수로 받아들이고 혼동 행렬을 컬러 이미지로 출력하는 플롯_confusion_matrix()라는 사용자 정의 함수를 정의합니다. 행렬 각 셀의 색상은 해당 값의 크기를 나타냅니다. 다음으로, 각각의 실제 범주와 예측 범주를 사용하여 혼동행렬을 계산하고 위에서 정의한plot_confusion_matrix() 함수를 사용하여 혼동행렬을 표현해야 합니다.

4. 요약

Python 언어는 데이터 과학자와 연구자가 딥 러닝과 인공 지능 데이터 분석을 더 빠르게 수행할 수 있는 수많은 시각화 및 데이터 분석 라이브러리를 제공합니다. 이 기사에서는 혼동 행렬과 그 응용 프로그램을 소개하고, Python에서 혼동 행렬을 계산하는 방법과 matplotlib 라이브러리를 사용하여 혼동 행렬의 그래픽을 생성하는 방법을 소개합니다. 혼동행렬 기술은 딥러닝과 인공지능 분야에서 중요한 응용분야를 갖고 있으므로, 혼동행렬 기술을 배우는 것이 매우 필요합니다.

위 내용은 Python의 혼동 행렬 트릭의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.