I'm trying to create a masked version of SparseCategoricalAccuracy in TF 2.0 that can be passed to the Keras API via compile(metrics=[masked_accuracy_fn()]).
The function is as follows:
import tensorflow as tf

def get_masked_acc_metric_fn(ignore_label=-1):
    """Gets the masked accuracy function."""
    def masked_acc_fn(y_true, y_pred):
        """Masked accuracy."""
        y_true = tf.squeeze(y_true)
        # Create mask for time steps we don't care about
        mask = tf.not_equal(y_true, ignore_label)
        # A new metric object (with its own variables) is constructed on every call
        masked_acc = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_masked_accuracy', dtype=tf.float32)(y_true, y_pred, sample_weight=mask)
        return masked_acc
    return masked_acc_fn
This works in eager mode. However, when running in graph mode, I get the following error:
ValueError: tf.function-decorated function tried to create variables on non-first call
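The error is consistent with how Keras metrics work: constructing a SparseCategoricalAccuracy object creates tf.Variables (its running total and count), and the code above constructs a new one on every call, while a tf.function may only create variables on its first trace. One possible way around it, sketched below, is a stateless function built only from tensor ops, so nothing creates variables. This is an illustrative alternative, not the accepted answer; it assumes one integer label per prediction position (an optional trailing size-1 label axis is reshaped away), and the body of masked_acc_fn here is a rewrite rather than the original code.

```python
import tensorflow as tf


def get_masked_acc_metric_fn(ignore_label=-1):
    """Returns a stateless masked accuracy function (creates no tf.Variables)."""

    def masked_acc_fn(y_true, y_pred):
        # y_true: integer class ids, shape (batch, ...) or (batch, ..., 1)
        # y_pred: class scores, shape (batch, ..., num_classes)
        y_true = tf.cast(y_true, tf.int64)
        # Drop a trailing size-1 label axis so y_true matches y_pred's leading dims
        y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])
        pred_labels = tf.argmax(y_pred, axis=-1)  # int64
        # 1.0 where the label should be counted, 0.0 where it equals ignore_label
        mask = tf.cast(tf.not_equal(y_true, ignore_label), tf.float32)
        correct = tf.cast(tf.equal(y_true, pred_labels), tf.float32) * mask
        # divide_no_nan returns 0 when every label in the batch is ignored
        return tf.math.divide_no_nan(tf.reduce_sum(correct), tf.reduce_sum(mask))

    return masked_acc_fn
```

It would be passed as compile(metrics=[get_masked_acc_metric_fn()]). One trade-off: Keras wraps a plain function metric in a mean metric, so the reported epoch value is the average of per-batch masked accuracies. This can differ from the exact masked accuracy when batches contain different numbers of non-ignored labels.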
Answer 0 (score: 0)
This appears to work as a temporary workaround:
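import tensorflow as tf

class MaskedSparseCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    def __init__(self, name="masked_sparse_categorical_accuracy", dtype=None):
        super(MaskedSparseCategoricalAccuracy, self).__init__(name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None, ignore_label=-1):
        # Weight 0 for ignore_label positions; also keep any Keras-supplied sample_weight (assumed same shape as y_true)
        mask = tf.cast(tf.not_equal(y_true, ignore_label), self.dtype)
        if sample_weight is not None:
            mask = mask * tf.cast(sample_weight, self.dtype)
        return super(MaskedSparseCategoricalAccuracy, self).update_state(y_true, y_pred, sample_weight=mask)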
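The subclass works because the metric object, and therefore its variables, is created once, and each update_state call only adds to them. A Keras mean-style metric accumulates a weighted total (sum of weight times correct) and a count (sum of weights) across batches, and result() divides the two. The NumPy sketch below mirrors that accumulation so the expected value can be checked by hand. The helper names masked_accuracy_counts and streaming_masked_accuracy are hypothetical and are not part of Keras.

```python
import numpy as np


def masked_accuracy_counts(y_true, y_pred, ignore_label=-1):
    """Hypothetical helper: (num_correct, num_counted) for one batch, skipping ignore_label."""
    y_true = np.asarray(y_true)
    pred = np.argmax(np.asarray(y_pred), axis=-1)  # predicted class per position
    mask = y_true != ignore_label                  # True where the label is counted
    correct = (pred == y_true) & mask
    return float(correct.sum()), float(mask.sum())


def streaming_masked_accuracy(batches, ignore_label=-1):
    """Hypothetical helper: accumulate total and count over batches, then divide once."""
    total, count = 0.0, 0.0
    for y_true, y_pred in batches:
        c, n = masked_accuracy_counts(y_true, y_pred, ignore_label)
        total += c
        count += n
    return total / count if count else 0.0
```

For example, a batch with labels [0, 1, -1, 2] and predictions [0, 2, 2, 2] contributes 2 correct out of 3 counted, and a batch with labels [1, -1] and predictions [1, 0] contributes 1 out of 1. Accumulating gives 3/4 = 0.75, whereas averaging the per-batch values would give (2/3 + 1) / 2, roughly 0.833.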