在Tensorflow Keras中创建自定义指标类

时间:2019-11-15 11:55:51

标签: tensorflow keras

我想建立一个指标,以计算组级别的精度。

例如假设形状为(batch, 10, 1)的LSTM输出,我想沿时间维度分组(按10个时间戳分组)并计算精度。

我创建了指标,该指标继承了精度:

class PrecisionGrouped(tf.keras.metrics.Precision):
  def __init__(self,
               thresholds=None,
               top_k=None,
               class_id=None,
               name=None,
               dtype=None):
    super(PrecisionGrouped, self).__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.math.reduce_max(y_true, axis=1)
    y_pred = tf.math.reduce_max(y_pred, axis=1)
    return super().update_state(y_true, y_pred, sample_weight)

但是,当我运行代码时,它抱怨update_state方法应该返回张量。 但是我只是在调用父方法,该方法还会返回update_op

TypeError: To be compatible with tf.contrib.eager.defun, Python functions must return zero or more Tensors; in compilation of <function PrecisionGrouped.update_state at 0x1a384b1400>, found return value of type <class 'tensorflow.python.framework.ops.Operation'>, which is not a Tensor.

我要做的就是在输入中添加一个简单的预处理步骤。 tf.keras.metrics.Precision()正常工作,但PrecisionGrouped()无效。

1 个答案:

答案 0 :(得分:2)

您可以从update_state()方法中删除收益。

def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.math.reduce_max(y_true, axis=1)
    y_pred = tf.math.reduce_max(y_pred, axis=1)
    super().update_state(y_true, y_pred, sample_weight)
  

您可以从自定义指标中删除返回声明和组操作。不需要。由于TPU存在问题,内置指标有不同的要求。一旦解决此问题,我们还将从内置指标中删除update_state的返回值。

有关更多详细信息,请参考此GitHub issue