Keras中训练模型的预测功能的问题

时间:2019-04-04 06:13:16

标签: python tensorflow keras neural-network conv-neural-network

我正在对某些图像集执行分类问题,其中我的类别数是3。现在,由于我正在执行CNN,因此它具有卷积层和Pooling层,然后是几个密集层;模型参数如下所示:

def baseline_model():
    model = Sequential()
    model.add(Conv2D(32, (5, 5), input_shape=(1, 100, 100), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(60, activation='relu'))
    model.add(Dropout(0.2))    
    model.add(Dense(num_classes, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

该模型运行完美,并向我显示了准确性和验证错误等。如下所示:

model = baseline_model()
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=5, batch_size=20, verbose=1)
scores = model.evaluate(X_test, y_test, verbose=0)
print("CNN Error: %.2f%%" % (100-scores[1]*100))

哪个给我输出:

Train on 514 samples, validate on 129 samples
Epoch 1/5
514/514 [==============================] - 23s 44ms/step - loss: 1.2731 - acc: 0.4202 - val_loss: 1.0349 - val_acc: 0.4419
Epoch 2/5
514/514 [==============================] - 18s 34ms/step - loss: 1.0172 - acc: 0.4416 - val_loss: 1.0292 - val_acc: 0.4884
Epoch 3/5
514/514 [==============================] - 17s 34ms/step - loss: 0.9368 - acc: 0.5817 - val_loss: 0.9915 - val_acc: 0.4806
Epoch 4/5
514/514 [==============================] - 18s 34ms/step - loss: 0.7367 - acc: 0.7101 - val_loss: 0.9973 - val_acc: 0.4961
Epoch 5/5
514/514 [==============================] - 17s 32ms/step - loss: 0.4587 - acc: 0.8521 - val_loss: 1.2328 - val_acc: 0.5039
CNN Error: 49.61%

问题发生在预测部分。 因此对于我的测试图像,我需要为其预测;当我运行model.predict()时,它给了我这个错误:

TypeError: data type not understood

如果需要,我可以显示完整的错误。 只是为了展示,我训练图像的形状以及我最终用来预测的图像:

X_train.shape
(514, 1, 100, 100)

final.shape
(277, 1, 100, 100)

所以我不知道此错误是什么意思,这是什么问题。甚至我图像值的数据类型都是相同的'float32'。因此形状相同且数据类型相同,那么为什么会出现此问题?

1 个答案:

答案 0 :(得分:1)

类似于predict with Keras fails due to faulty environment setup 我在anaconda和python 3.7中遇到了相同的问题。我改用WPy-3670时解决了 使用python 3.6,一切都降级了。