我想评估模型的准确性,但也要实现cifar10数据集的所有10个类的混淆矩阵,因此我收到此错误消息“检查输入时出错:期望conv2d_9_input具有4个维,但得到数组形状(10000,10)“
def run_test_harness():
# load dataset
trainX, trainY, testX, testY = load_dataset()
# prepare pixel data
trainX, testX = prep_pixels(trainX, testX)
# define model
model = define_model()
# fit model
history = model.fit(trainX, trainY, epochs=100, batch_size=64, validation_data=(testX, testY), verbose=0)
# fig
y_pred=model.predict_classes(testY)
con_mat = tf.math.confusion_matrix(labels=y_true, predictions=y_pred).numpy()
con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)
con_mat_df = pd.DataFrame(con_mat_norm, index = classes, columns = classes)
figure = plt.figure(figsize=(8, 8))
sns.heatmap(con_mat_df, annot=True,cmap=plt.cm.Blues)
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
accuracy, precision, recall = model.evaluate(testX, testY, verbose=0)
print ("recall")
print ('> %.3f' % (recall * 100.0))
print ("accuracy")
print('> %.3f' % (accuracy * 100.0))
print ("precision")
print('> %.3f' % (precision * 100.0))
#学习曲线 #结束
accuracy, precision, recall = model.evaluate(testX, testY, verbose=0)
print ("recall")
print ('> %.3f' % (recall * 100.0))
print ("accuracy")
print('> %.3f' % (accuracy * 100.0))
print ("precision")
print('> %.3f' % (precision * 100.0))
这是在Cnn上的实现
def define_model():
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same', input_shape=(32, 32, 3)))
model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
model.add(Dense(10, activation='softmax'))
# compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy',precision_m, recall_m])
return model
答案 0 :(得分:0)
我的猜测是trainX或testX被某处的y数据交换了。适用于trainX的预期输入形状是:(batchsize,32,32,3),我认为批量大小是10000。trainY可能具有形状(10000,10)。您可能想在调用model.fit之前检查trainX,trainY,testX和testY的形状,以确保没有交换或其他损坏。我希望这会有所帮助。