我正在使用Keras模型上的自定义损失函数。我的“自定义”损失似乎失败了(就准确性得分而言),即使我只使用返回原始角膜损失的包装器。
作为一个玩具示例,我正在使用"Basic classification" Tensorflow / Keras教程,该教程在fashion-MNIST数据集上使用了一个简单的NN,并且正在关注相关的Keras documentation和{{3} }因此发布。
这是模型:
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
现在,如果我将sparse_categorical_crossentropy
保留为compile()
函数中的字符串参数,则训练结果的准确性约为〜87%:
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)
但是当我创建一个琐碎的包装函数来调用keras的交叉熵时,我在训练和测试集上的准确性都达到了约10%:
from tensorflow.keras import losses
def my_loss(y_true, y_pred):
return losses.sparse_categorical_crossentropy(y_true, y_pred)
model.compile(optimizer='adam',
loss=my_loss,
metrics=['accuracy'])
Epoch 1/10 60000/60000 [=============================]-3s 51us / sample-损失: 0.5030-精度:0.1032
时代2/10 60000/60000 [==============================]-3s 45us / sample-损耗:0.3766-精度:0.1035...
测试准确度:0.1013
通过绘制一些图像并检查其分类标签,看起来每种情况下的结果看起来都不相同,但是打印的准确性却大不相同。那么,是否存在默认指标无法很好地应对自定义损失的情况?我看到的是误差而不是准确性吗?我在文档中缺少什么吗?
编辑:两种情况下损失函数的值最终大致相同,因此确实需要进行训练。准确性是失败的关键。
答案 0 :(得分:0)
原因如下:
当您使用内置损耗并在那时使用loss='sparse_categorical_crossentropy'
时,使用的精度指标为sparse_categorical_accuracy
。但是,当您此时使用自定义损耗函数时,使用的精度指标为categorical_accuracy
。
示例:
model.compile(optimizer='adam',
loss=losses.sparse_categorical_crossentropy,
metrics=['categorical_accuracy', 'sparse_categorical_accuracy'])
model.fit(train_images, train_labels, epochs=1)
'''
Train on 60000 samples
60000/60000 [==============================] - 5s 86us/sample - loss: 0.4955 - categorical_accuracy: 0.1045 - sparse_categorical_accuracy: 0.8255
'''
model.compile(optimizer='adam',
loss=my_loss,
metrics=['accuracy', 'sparse_categorical_accuracy'])
model.fit(train_images, train_labels, epochs=1)
'''
Train on 60000 samples
60000/60000 [==============================] - 5s 87us/sample - loss: 0.4956 - acc: 0.1043 - sparse_categorical_accuracy: 0.8256
'''