MNIST classification: mean_squared_error loss function and tanh activation function

Posted: 2018-11-19 09:27:56

Tags: tensorflow machine-learning keras neural-network classification

I changed TensorFlow's getting started example as follows:

import tensorflow as tf
from sklearn.metrics import roc_auc_score
import numpy as np
import commons as cm  # local helper module (Histories callback, plotting and confusion-matrix helpers)
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  # tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  # tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.tanh)
])
model.compile(optimizer='adam',
              loss='mean_squared_error',
              # loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = cm.Histories()
h = model.fit(x_train, y_train, epochs=50, callbacks=[history])
print("history:", history.losses)
cm.plot_history(h)
# cm.plot(history.losses, history.aucs)


test_predictions = model.predict(x_test)


# Compute confusion matrix
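pred = np.argmax(test_predictions, axis=1)  # predicted class = index of the largest output
# pred2 = model.predict_classes(x_test)  # same result as above; predict_classes was removed in newer TensorFlow releases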
confusion = confusion_matrix(y_test, pred)
cm.draw_confusion(confusion,range(10))

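Its default parameters are:

  • relu activation in the hidden layer,
  • softmax in the output layer, and
  • sparse_categorical_crossentropy as the loss function.

With these it works fine: the prediction accuracy for every digit is above 99%.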

However, with my parameters (tanh activation and mean_squared_error loss), it just predicts 0 for every test sample.


What is the problem? The accuracy keeps increasing with each epoch and reaches 99%, while the loss stays at about 20.
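A loss of about 20 is roughly what you would expect if each integer label (0 to 9) were compared with all ten tanh outputs. That is an assumption about how the labels end up matched against the 10-unit output layer; the post does not show it. Because tanh cannot exceed 1, the smallest squared error an output can reach for a label y >= 1 is (y - 1)^2. The numpy sketch below computes that floor for equally frequent labels, and shows that when the ten outputs are (nearly) equal, np.argmax picks index 0, which would match the all-zero predictions.

```python
import numpy as np

# Assumption: each integer label y in 0..9 is compared against all ten
# tanh outputs, and all labels are equally frequent.
labels = np.arange(10)

# tanh outputs lie in [-1, 1], so the best any output can do for label y
# is clip(y, -1, 1); the squared error left over is the loss floor.
best_output = np.clip(labels, -1.0, 1.0)
floor_per_label = (labels - best_output) ** 2

mean_floor = floor_per_label.mean()
print(mean_floor)  # 20.4

# Under the same assumption, all ten outputs share one target per input,
# so training pushes them toward the same value (near 1 for labels >= 1).
# np.argmax breaks ties by returning the first index, i.e. class 0.
saturated_outputs = np.ones(10)
print(np.argmax(saturated_outputs))  # 0
```

The floor depends only on the label range, not on the network, which is why more epochs cannot push the loss much below 20 with this setup.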

1 Answer

Answer 0 (score: 1)

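You need to use a loss function that fits your data. Here you have a categorical output, so you need to use sparse_categorical_crossentropy, but also set from_logits=True and use no activation on the last layer.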
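A minimal sketch of this fix, assuming TensorFlow 2.x, where tf.keras.losses.SparseCategoricalCrossentropy accepts from_logits=True: the last Dense layer has no activation, and the loss applies the softmax internally. The tanh hidden layer is kept from the question; the random arrays are placeholders so the sketch runs without downloading MNIST.

```python
import numpy as np
import tensorflow as tf

# Same architecture as in the question, but the output layer has no
# activation: it produces raw scores (logits).
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation="tanh"),
    tf.keras.layers.Dense(10),  # no activation -> logits
])
model.compile(
    optimizer="adam",
    # from_logits=True: softmax is applied inside the loss
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Placeholder data standing in for MNIST (use mnist.load_data() for real runs).
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)  # integer labels 0..9, no one-hot needed
model.fit(x, y, epochs=1, verbose=0)

logits = model.predict(x, verbose=0)
pred = np.argmax(logits, axis=1)  # argmax of logits == argmax of softmax
print(logits.shape, pred[:10])
```

Folding the softmax into the loss is the numerically more stable option; keeping a softmax output layer with the default from_logits=False, as in the original example, also works.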

If you need to use tanh as your output, then you can use MSE with a one-hot encoded version of your labels, plus rescaling.
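A minimal numpy sketch of the one-hot plus rescaling targets. The answer does not say which range to rescale to; this assumes {0, 1} is mapped to {-high, +high} with high=1.0 by default, the limits of tanh's output range. The helper name tanh_targets is hypothetical. Because tanh only approaches ±1 asymptotically, a value slightly inside the range (for example high=0.9) is a common alternative.

```python
import numpy as np

def tanh_targets(labels, num_classes=10, high=1.0):
    """Hypothetical helper: one-hot encode integer labels, then map
    1 -> +high and 0 -> -high to match a tanh output layer."""
    one_hot = np.eye(num_classes, dtype="float32")[labels]
    return high * (2.0 * one_hot - 1.0)

y = np.array([3, 0, 9])
targets = tanh_targets(y)
print(targets[0])

# Decoding: the predicted class is the output closest to +high, i.e. the argmax.
print(np.argmax(targets, axis=1))  # [3 0 9]

# With the question's model (tanh output layer, loss='mean_squared_error'):
# model.fit(x_train, tanh_targets(y_train), epochs=5)
```

Each sample now has a distinct target vector with a single +high entry, so the outputs no longer all chase the same value, and np.argmax over the predictions recovers the class as before.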