Keras提供categorical_accuracy
,sparse_categorical_accuracy
和top_k_categorical_accuracy
指标。
我想定义top_k_categorical_accuracy
的稀疏版本。
我的暂定实施如下:
def top_k_sparse_categorical_accuracy(y_true, y_pred, z=5):
return K.mean(K.in_top_k(y_pred, K.max(y_true, axis=-1), z), axis=-1)
这是在top_k_categorical_accuracy
和sparse_categorical_accuracy
def sparse_categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.max(y_true, axis=-1),
K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
K.floatx())
def top_k_categorical_accuracy(y_true, y_pred, k=5):
return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
但是,这不起作用,因为K.in_top_k
的第二个参数应该是一个整数的张量,这里我有一个浮点张量,K.cast
只能转向浮点数。
有没有办法在Keras后端执行强制转换?或者是否有另一种方法来实现此指标?