Creating a modified version of SparseCategoricalAccuracy, getting ValueError: tf.function-decorated function tried to create variables on non-first call

Posted: 2019-12-15 16:48:23

Tags: tensorflow2.0 tf.keras eager-execution

I'm trying to create a masked version of SparseCategoricalAccuracy in TF 2.0 that can be passed to the Keras API via compile(metrics=[masked_accuracy_fn()]).
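For reference, a masked sparse categorical accuracy should score only the positions whose label is not ignore_label: correct unmasked predictions divided by the number of unmasked positions. The framework-free sketch below computes that quantity so a TF implementation can be checked against it; the helper name masked_accuracy is hypothetical, and ties in y_pred resolve to the first maximum, like argmax.

```python
from typing import Sequence


def masked_accuracy(y_true: Sequence[int],
                    y_pred: Sequence[Sequence[float]],
                    ignore_label: int = -1) -> float:
    """Hypothetical reference: accuracy over positions where y_true != ignore_label."""
    correct = 0
    counted = 0
    for label, probs in zip(y_true, y_pred):
        if label == ignore_label:
            continue  # equivalent to a sample weight of 0
        # Index of the largest score; max() keeps the first maximum on ties, like argmax.
        predicted = max(range(len(probs)), key=probs.__getitem__)
        counted += 1
        correct += int(predicted == label)
    # Return 0.0 when every position is masked instead of dividing by zero.
    return correct / counted if counted else 0.0


print(masked_accuracy([0, 1, -1], [[0.9, 0.1, 0.0], [0.8, 0.2, 0.0], [0.1, 0.1, 0.8]]))  # 0.5
```

Without the mask, the third position (label -1) could never match a predicted class, so plain accuracy on this toy batch would be 1/3 instead of 1/2.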

Here is my implementation:

def get_masked_acc_metric_fn(ignore_label=-1):
    """Gets the masked accuracy function."""
    def masked_acc_fn(y_true, y_pred):
        """Masked accuracy."""
        y_true = tf.squeeze(y_true)
        # Create mask for time steps we don't care about
        mask = tf.not_equal(y_true, ignore_label)
        # Note: this constructs a new metric object (with new state variables) on every call.
        masked_acc = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_masked_accuracy', dtype=tf.float32)(y_true, y_pred, sample_weight=mask)
        return masked_acc

    return masked_acc_fn

This works in eager mode. However, when running in graph mode, I get the error:

ValueError: tf.function-decorated function tried to create variables on non-first call
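The error can be reproduced without Keras. A tf.function may only create tf.Variables on its first trace, and constructing a metric object creates variables, so building SparseCategoricalAccuracy inside a traced function hits the same restriction. The sketch below (function names are hypothetical) triggers the error with a bare tf.Variable, then shows the usual fix: create the stateful object once, outside the tf.function, and only call update_state inside it.

```python
import tensorflow as tf


# Anti-pattern: the traced body creates a new tf.Variable every time it runs.
@tf.function
def creates_variable_each_call(x):
    v = tf.Variable(1.0)
    return v + x


try:
    creates_variable_each_call(tf.constant(1.0))
    reproduced = False
except ValueError:
    # tf.function rejects variables created after the first trace.
    reproduced = True

# Fix: the metric (and its state variables) is created once, outside tf.function.
acc = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def update_metric(y_true, y_pred, mask):
    # Only existing variables are updated here, so re-tracing/re-running is safe.
    acc.update_state(y_true, y_pred, sample_weight=mask)
    return acc.result()


y_true = tf.constant([0, 1, -1])
y_pred = tf.constant([[0.9, 0.1, 0.0], [0.8, 0.2, 0.0], [0.1, 0.1, 0.8]])
mask = tf.cast(tf.not_equal(y_true, -1), tf.float32)  # weight 0 for label -1
first = float(update_metric(y_true, y_pred, mask))
second = float(update_metric(y_true, y_pred, mask))  # second call: no new variables
print(reproduced, first, second)
```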

1 Answer

Answer 0 (score: 0)

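This seems to work as a temporary workaround. The key difference is that the metric object, and therefore its state variables, is created once when the class is instantiated (e.g. in compile(metrics=[MaskedSparseCategoricalAccuracy()])), rather than inside the function that Keras traces on every step; the traced code then only calls update_state on variables that already exist: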

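import tensorflow as tf

class MaskedSparseCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    def __init__(self, name="masked_sparse_categorical_accuracy", dtype=None, ignore_label=-1):
        # State variables are created once here, not inside the traced train/test step.
        super(MaskedSparseCategoricalAccuracy, self).__init__(name, dtype=dtype)
        self.ignore_label = ignore_label

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Weight 0 for positions labelled ignore_label, 1 elsewhere.
        mask = tf.cast(tf.not_equal(y_true, self.ignore_label), self.dtype)
        if sample_weight is not None:  # assumes sample_weight has the same shape as y_true
            mask = mask * tf.cast(sample_weight, self.dtype)
        return super(MaskedSparseCategoricalAccuracy, self).update_state(y_true, y_pred, sample_weight=mask)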
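A different route, not the one taken in the answer above: keep the metric as a plain function but build it on the stateless tf.keras.metrics.sparse_categorical_accuracy function, which returns per-position 0/1 values and creates no variables, so it is safe to call inside tf.function. The factory name make_masked_accuracy_fn is hypothetical; this is a sketch, not a drop-in replacement.

```python
import tensorflow as tf


def make_masked_accuracy_fn(ignore_label=-1):
    """Hypothetical stateless masked accuracy usable as compile(metrics=[...])."""
    def masked_accuracy(y_true, y_pred):
        # Per-position 0/1 correctness; the functional form creates no tf.Variables.
        correct = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
        # Flatten both so shapes match whether or not y_true has a trailing 1-dim.
        correct = tf.reshape(correct, [-1])
        labels = tf.reshape(y_true, [-1])
        mask = tf.cast(tf.not_equal(labels, ignore_label), correct.dtype)
        # Mean over unmasked positions only; 0.0 if the whole batch is masked.
        return tf.math.divide_no_nan(tf.reduce_sum(correct * mask), tf.reduce_sum(mask))
    return masked_accuracy


fn = tf.function(make_masked_accuracy_fn())
y_true = tf.constant([0, 1, -1])
y_pred = tf.constant([[0.9, 0.1, 0.0], [0.8, 0.2, 0.0], [0.1, 0.1, 0.8]])
first = float(fn(y_true, y_pred))
second = float(fn(y_true, y_pred))  # repeated calls are fine: nothing is created
print(first, second)
```

One trade-off to be aware of: when a plain function is passed to compile, Keras averages its per-batch return values, so batches with few unmasked positions count as much as batches with many. The stateful subclass accumulates weighted totals across batches and reports the exact element-level masked accuracy.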