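Matplotlib not displaying training/testing loss and accuracy curves?

Asked: 2021-01-19 10:47:38

Tags: python tensorflow matplotlib deep-learning conv-neural-network

I am trying to plot accuracy and loss curves with Matplotlib, but no curves appear. Only the empty axes are shown, and the x-axis starts at a negative value. Why does it not start at 0?

Code: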

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

max_accuracy=0.70
for i in range(10):
    print("Epoch no",i+1)
    history = model.fit(train_data,train_labels, epochs=1, batch_size=32,verbose=1,validation_data=(test_data,test_labels))
    if history.history['val_accuracy'][0]>max_accuracy:
        print("New Best model found above")
        max_accuracy=history.history['val_accuracy'][0]
        model.save('CNN-logo.h5')
        
model=tf.keras.models.load_model('CNN-logo.h5')
[train_loss, train_accuracy] = model.evaluate(train_data, train_labels)
print("Evaluation result on Train Data : Loss = {}, accuracy = {}".format(train_loss, train_accuracy))
[test_loss, test_acc] = model.evaluate(test_data, test_labels)
print("Evaluation result on Test Data : Loss = {}, accuracy = {}".format(test_loss, test_acc))
#Plot the loss curves
plt.figure(figsize=[8,6])
plt.plot(history.history['loss'],'r',linewidth=3.0)
plt.plot(history.history['val_loss'],'b',linewidth=3.0)
plt.legend(['Training loss', 'Validation Loss'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Loss',fontsize=16)
plt.title('Loss Curves',fontsize=1)
plt.show()
#Plot the Accuracy Curves
plt.figure(figsize=[8,6]) 
plt.plot(history.history['accuracy'],'r',linewidth=3.0) 
plt.plot(history.history['val_accuracy'],'b',linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16) 
plt.ylabel('Accuracy',fontsize=16)
plt.title('Accuracy Curves',fontsize=16)
plt.show()

Plot:
[image: Only displaying graph, not curves]

Answers:

Answer 0 (score: 1)

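You are running the model for only 1 epoch in each fit call, and history is overwritten on every loop iteration, so it holds a single epoch of history. One point per metric gives Matplotlib nothing to connect into a curve. Train for all epochs in one fit call instead; to save the best model along the way, you can use callbacks: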

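import tensorflow as tf
import matplotlib.pyplot as plt

# Save the model only when val_accuracy improves; stop early if the training loss stalls.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath='CNN-logo.h5',
                      monitor='val_accuracy', mode='max', save_best_only=True)
earlystopping_callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# One fit call over many epochs, so history holds one value per epoch.
history = model.fit(x, y, epochs=50, batch_size=32, verbose=1, validation_split=0.2,
                    callbacks=[checkpoint_callback, earlystopping_callback])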

# Reload the best checkpoint written by ModelCheckpoint.
model = tf.keras.models.load_model('CNN-logo.h5')
[train_loss, train_accuracy] = model.evaluate(x, y)
print("Evaluation result on Train Data : Loss = {}, accuracy = {}".format(train_loss, train_accuracy))
# x, y are reused here as placeholders; pass held-out test data for a real test score.
[test_loss, test_acc] = model.evaluate(x, y)
print("Evaluation result on Test Data : Loss = {}, accuracy = {}".format(test_loss, test_acc))

#Plot the loss curves
plt.figure(figsize=[8,6])
plt.plot(history.history['loss'],'r',linewidth=3.0)
plt.plot(history.history['val_loss'],'b',linewidth=3.0)
plt.legend(['Training loss', 'Validation Loss'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Loss',fontsize=16)
plt.title('Loss Curves',fontsize=16)
plt.show()

#Plot the Accuracy Curves
plt.figure(figsize=[8,6]) 
plt.plot(history.history['accuracy'],'r',linewidth=3.0) 
plt.plot(history.history['val_accuracy'],'b',linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16) 
plt.ylabel('Accuracy',fontsize=16)
plt.title('Accuracy Curves',fontsize=16)
plt.show()

Output:

[Two plot images, not reproduced here]
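
If the manual loop from the question is kept (one fit call per epoch), each call returns a fresh History object, so the per-epoch values have to be collected by hand before plotting. A minimal sketch of that bookkeeping follows, using plain dicts in place of history.history. The helper name merge_histories and the sample numbers are made up for illustration and are not part of Keras.

```python
from collections import defaultdict


def merge_histories(per_epoch_histories):
    """Concatenate several history.history-style dicts into one dict of lists."""
    merged = defaultdict(list)
    for epoch_history in per_epoch_histories:
        for metric, values in epoch_history.items():
            merged[metric].extend(values)
    return dict(merged)


# Hypothetical values standing in for three model.fit(..., epochs=1) calls.
collected = [
    {"loss": [0.9], "val_loss": [1.0]},
    {"loss": [0.7], "val_loss": [0.8]},
    {"loss": [0.5], "val_loss": [0.75]},
]
full_history = merge_histories(collected)
print(full_history["loss"])
```

Inside the loop, this would mean appending history.history to a list after every fit call and then plotting full_history['loss'] rather than history.history['loss'].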

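
On why the x-axis starts below 0: with a single value per metric, Matplotlib has one point at x = 0 and no line segment to draw, and its autoscaling pads the axis on both sides of that point. The sketch below reproduces this without training a model. The loss values are made up, and the Agg backend is chosen only so the script runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to open plot windows
import matplotlib.pyplot as plt

one_epoch_loss = [0.8]               # shape of history.history['loss'] after epochs=1
three_epoch_loss = [0.8, 0.6, 0.5]   # hypothetical values for three epochs

fig, (ax_single, ax_multi) = plt.subplots(1, 2, figsize=(10, 4))

# One point with a line-only style: there is no segment to draw.
ax_single.plot(one_epoch_loss, "r", linewidth=3.0)
ax_single.set_title("1 epoch")

# Several points: explicit epoch numbers and markers make each value visible.
epochs = range(1, len(three_epoch_loss) + 1)
ax_multi.plot(epochs, three_epoch_loss, "r-o", linewidth=3.0)
ax_multi.set_xticks(list(epochs))
ax_multi.set_title("3 epochs")

# Autoscaled limits around the lone point at x = 0.
single_xlim = ax_single.get_xlim()
print("x-limits with one point:", single_xlim)
```

Adding a marker to the format string (for example 'r-o') also makes a single point visible, which is a quick way to check whether history really contains only one epoch. Call fig.savefig or plt.show to view the two panels.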