Keras的categorical_accuracy
和sparse_categorical_accuracy
之间有什么区别? documentation for these metrics中没有任何暗示,并且通过询问谷歌博士,我也没有找到答案。
可以找到源代码here:
def categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.argmax(y_true, axis=-1),
K.argmax(y_pred, axis=-1)),
K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.max(y_true, axis=-1),
K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
K.floatx())
答案 0 :(得分:39)
因此,在categorical_accuracy
中,您需要将目标(y
)指定为单热编码向量(例如,在3个类的情况下,当真正的类是第二类时,y
应该是(0, 1, 0)
。在sparse_categorical_accuracy
中,你需要只提供一个真实类的整数(在前一个例子的情况下 - 它将是1
,因为类索引是0
基)。
答案 1 :(得分:24)
查看source
def categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.argmax(y_true, axis=-1),
K.argmax(y_pred, axis=-1)),
K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.max(y_true, axis=-1),
K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
K.floatx())
categorical_accuracy
检查最大真值的 index 是否等于最大预测值的 index 。
sparse_categorical_accuracy
检查最大真值是否等于最大预测值的 index 。
根据Marcin的答案,categorical_accuracy
对应one-hot
的{{1}}编码向量。
答案 2 :(得分:3)
sparse_categorical_accuracy
需要稀疏标签:[[0], [1], [2]]
例如:
import tensorflow as tf
sparse = tf.convert_to_tensor([[0], [1], [2]])
logits = tf.convert_to_tensor([[.8, .1, .1], [.5, .3, .2], [.2, .2, .6]])
sparse_cat_acc = tf.metrics.SparseCategoricalAccuracy()
sparse_cat_acc(sparse, logits)
<tf.Tensor: shape=(), dtype=float64, numpy=0.6666666666666666>
categorical_accuracy
需要一个热编码输入:[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]
例如:
onehot = tf.convert_to_tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
logits = tf.convert_to_tensor([[.8, .1, .1], [.5, .3, .2], [.2, .2, .6]])
cat_acc = tf.metrics.CategoricalAccuracy()
cat_acc(sparse, logits)
<tf.Tensor: shape=(), dtype=float64, numpy=0.6666666666666666>
答案 3 :(得分:0)
我刚刚发现的一个区别是指标名称的不同。
使用 categorical_accuracy
,这有效:
mcp_save_acc = ModelCheckpoint('model_' + 'val_acc{val_accuracy:.3f}.hdf5', save_best_only=True, monitor='val_accuracy', mode='max')
但是在切换到 sparse_categorical accuracy
之后,我现在需要这个:
mcp_save_acc = ModelCheckpoint('model_' + 'val_acc{val_sparse_categorical_accuracy:.3f}.hdf5', save_best_only=True, monitor='val_sparse_categorical_accuracy', mode='max')
即使我仍然有 metrics=['accuracy']
作为我的 compile()
函数的参数。
我有点希望 val_acc
和/或 val_accuracy
只是为所有 keras 的内置 *_crossentropy
损失工作。