我正在使用Keras进行面部表情识别,我有.csv格式的数据集:
Emotion | Pixels | Usage
其中 Emotion 是图像(Pixels 列)对应的真实情感标签,Usage 的取值为:Training(训练)、Validation(验证)或 Test(测试)。
我的模型生成为hdf5格式。
我想为测试数据(即 Usage 列等于 Test 的行)生成一个混淆矩阵,并以百分比的形式绘制预测情绪与真实情绪的对比。
我怎么能用sklearn做到这一点?
以下是我尝试修改以使其适用于我的测试数据的代码示例。
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
# Toy data: nine samples of four numeric features each, used to exercise the
# plotting code before switching to the real test set.
x = [
    [5.1, 3.5, 1.4, 0.2],
    [4.9, 3.0, 1.4, 0.2],
    [4.7, 3.2, 1.3, 0.2],
    [4.6, 3.1, 1.5, 0.2],
    [5.0, 3.6, 1.4, 0.2],
    [5.4, 3.9, 1.7, 0.4],
    [4.6, 3.4, 1.4, 0.3],
    [5.0, 3.4, 1.5, 0.2],
    [4.4, 2.9, 1.4, 0.2],
]
# Integer class label for each row of x.
y = [0, 0, 0, 0, 0, 0, 2, 2, 1]

# Class names indexed by integer label. This must be an ordered sequence: a
# set literal has no defined iteration order, so it cannot map label i to a
# name.
# NOTE(review): confirm this order matches the integer encoding used in the
# Emotion column of the dataset.
emotions = ['Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise']

# Hold out part of the data; random_state is fixed so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(x, y, random_state=0)

# Fit a linear SVM on the training split and predict the held-out samples.
classifier = svm.SVC(kernel='linear', C=0.01)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues):
    """Plot a row-normalized confusion matrix and save it as an image.

    Each row of ``cm`` is divided by its sum, so every cell shows the fraction
    of that row's samples that fell into the column's class. The figure is
    drawn on the current matplotlib figure and written to
    'Confusion_Matrix.png' in the working directory.

    Args:
        cm: 2-D array-like of counts. Rows are plotted on the 'True label'
            axis and columns on the 'Predicted label' axis.
        classes: Class names used as tick labels on both axes, in the same
            order as the rows/columns of ``cm``.
        title: Title drawn above the plot.
        cmap: Matplotlib colormap used for the cells.
    """
    # A set has no defined order; sort it so the tick labels are at least
    # deterministic. Ordered sequences are used exactly as given.
    if isinstance(classes, (set, frozenset)):
        labels = sorted(classes)
    else:
        labels = list(classes)

    counts = np.asarray(cm, dtype=float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Rows without any samples stay at 0 instead of turning into NaN (0/0).
    fractions = np.divide(
        counts,
        row_totals,
        out=np.zeros_like(counts),
        where=row_totals != 0,
    )

    # Draw the normalized values so cell colors agree with the printed text.
    plt.imshow(fractions, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # Both axes use the same label order: cell (i, j) pairs true class i with
    # predicted class j.
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)

    # Switch to white text on dark cells so the numbers remain readable.
    thresh = fractions.max() / 2. if fractions.size else 0.
    for i, j in itertools.product(range(fractions.shape[0]), range(fractions.shape[1])):
        plt.text(
            j,
            i,
            '{:.1%}'.format(fractions[i, j]),
            horizontalalignment="center",
            color="white" if fractions[i, j] > thresh else "black",
        )

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass after the axis labels exist so they are not clipped.
    plt.tight_layout()
    plt.savefig('Confusion_Matrix.png')
# Request one row/column per emotion. Without `labels`, confusion_matrix only
# covers the classes that occur in y_test or y_pred, so the matrix could be
# smaller than the number of tick labels drawn by the plot.
cnf_matrix = confusion_matrix(y_test, y_pred, labels=list(range(len(emotions))))
np.set_printoptions(precision=2)

plot_confusion_matrix(cnf_matrix, classes=emotions, title='Confusion matrix')
plt.show()
以下是如何访问测试数据的:
import csv

# Number of pixel values each image row is expected to carry.
expected_pixels = 2304

x = []  # one flat list of pixel intensities per test image
y = []  # integer emotion label for each entry of x

# Columns are read as: row[0] = emotion label, row[1] = space-separated
# pixel values, row[2] = usage tag.
with open('My_file.csv', newline='') as f:
    for row in csv.reader(f):
        # Skip blank/short lines and every row that is not tagged "Test"
        # (this also drops a header row, if the file has one).
        if len(row) < 3 or row[2] != "Test":
            continue

        pixels = [int(pixel) for pixel in row[1].split()]
        if len(pixels) != expected_pixels:
            raise ValueError(
                'expected {} pixels, got {}'.format(expected_pixels, len(pixels))
            )

        y.append(int(row[0]))
        x.append(pixels)