혼동행렬(Confusion Matrix)
- Classification 머신러닝 모델이 제대로 작동을 했는지 혼동을 했는지 알아볼 수 있는 행렬
- 행(row)는 실제 클래스, 열(column)은 예측한 클래스
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from pandas import DataFrame
%matplotlib inline
sns.set(font_scale=2)
1. Basic confusion matrix
- seaborn의 heatmap은 confusion matrix를 시각화 하는 함수
- 색상이 밝으면 높은 숫자, 색상이 어두우면 낮은 숫자를 나타냄
arr = [[5,0,0,0], # when input was A, prediction was all A
[0,10,0,0], # when input was B, prediction was all B
[0,0,15,0], # when input was C, prediction was all C
[0,0,0,5]] # when input was D, prediction was all D
df_cm = DataFrame(arr, index=[i for i in "ABCD"],
columns= [i for i in "ABCD"])
df_cm
A | B | C | D | |
---|---|---|---|---|
A | 5 | 0 | 0 | 0 |
B | 0 | 10 | 0 | 0 |
C | 0 | 0 | 15 | 0 |
D | 0 | 0 | 0 | 5 |
plt .figure(figsize = (7,5))
plt.title('confusion matrix without confusion')
sns.heatmap(df_cm, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10e6e5da0>
2. Confusion matrix with confusion
1) normalization 하지 않은 confusion matrix
normalization 하지 않은 confusion matrix의 문제점
- test input의 분포를 모르기 때문에 어떤 알파벳에 효율적인지 말할 수 없음
arr2 = [[9,1,0,0],
[1,15,3,1],
[5,0,24,1],
[0,4,1,15]]
df_cm2 = DataFrame(arr2, index = [i for i in "ABCD"],
columns = [i for i in "ABCD"])
df_cm2
A | B | C | D | |
---|---|---|---|---|
A | 9 | 1 | 0 | 0 |
B | 1 | 15 | 3 | 1 |
C | 5 | 0 | 24 | 1 |
D | 0 | 4 | 1 | 15 |
plt.figure(figsize=(7,5))
plt.title('confusion matrix without normalization')
sns.heatmap(df_cm2, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10ea21908>
2) normalization 한 confusion matrix
import numpy as np
from numpy import linalg as LA
total = np.sum(arr2, axis=1)
print(total)
[10 20 30 20]
arr3 = arr2/total[:,None]
arr3
array([[0.9 , 0.1 , 0. , 0. ],
[0.05 , 0.75 , 0.15 , 0.05 ],
[0.16666667, 0. , 0.8 , 0.03333333],
[0. , 0.2 , 0.05 , 0.75 ]])
df_cm3 = DataFrame(arr3, index = [i for i in "ABCD"],
columns = [i for i in "ABCD"])
plt.figure(figsize = (7,5))
plt.title('confusion matrix with normalization')
sns.heatmap(df_cm3, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10eb4bb00>
다중 분류 모델 성능 측정(accuracy, f1 score, precision, recall)
1. 성능 지표(Performance measures)
Accuracy: 가장 많이 사용되는 지표, data가 balance할 때 효과적
F1 score: recall과 precision의 값을 가지고 조합평균(harmonic mean)을 낸 값, data가 imbalance할 때 효과적
Precision(confusion matrix에서 column 방향)
- 정밀도, classfier가 예측한 값들 중에서 정말로 예측한 값이 맞는가?
Recall(confusion matrix에서 row 방향)
- 재현율, 클래스가 주어졌을 때 classifier가 잘 예측을 하는가?
1) 개념 정리
- TP(true positive): true를 ture로 잘 예측한 것
- TN(true negative): false를 false로 잘 예측한 것
- FP(false positive): false를 true로 잘 못 예측한 것
- FN(false negative): true를 false로 잘 못 예측한 것
2. 예시: Model 1과 Model 2 중 어느 것을 선택해야 하는가?
1) Balanced data
- Model 1
arr4 = [[10,0,0,0],
[0,5,3,2],
[0,1,8,1],
[0,1,0,9]]
df_cm4 = DataFrame(arr4, index = [i for i in "ABCD"],
columns = [i for i in "ABCD"])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm4, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10eca5da0>
- Model 2
arr5 = [[8,2,0,0],
[1,7,0,2],
[0,0,9,1],
[2,3,0,5]]
df_cm5 = DataFrame(arr5, index = [i for i in "ABCD"],
columns = [i for i in "ABCD"])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm5, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10edd5860>
2) Imbalanced data
- Model 1
arr6 = [[100,80,10,10],
[0,9,0,1],
[0,1,8,1],
[0,1,0,9]]
df_cm6 = DataFrame(arr6, index = [i for i in "ABCD"],
columns = [i for i in "ABCD"])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm6, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10ed8cbe0>
- Model 2
arr7 = [[198,2,0,0],
[7,1,0,2],
[0,8,1,1],
[2,3,4,1]]
df_cm7 = DataFrame(arr7, index = [i for i in "ABCD"],
columns = [i for i in "ABCD"])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm7, annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x10ec685c0>
** 참고
'AI&BigData > Basics' 카테고리의 다른 글
[Pandas] 기술통계 계산 2 (0) | 2018.06.03 |
---|---|
[Pandas] 기술통계 계산 1 (0) | 2018.06.03 |
[Pandas] 정렬과 순위 (1) | 2018.05.06 |
[Pandas] 함수 적용과 매핑 (0) | 2018.05.06 |
[Pandas] Operation (0) | 2018.05.06 |
[Pandas] Index 객체, reindex (0) | 2018.05.06 |
[Pandas] DataFrame (0) | 2018.04.30 |
[Pandas] Series 객체 (0) | 2018.04.29 |
[Numpy] 브로드캐스트. 기타활용 (0) | 2018.04.24 |