我想建立一个指标,以计算组级别的精度。
例如假设形状为(batch, 10, 1)
的LSTM输出,我想沿时间维度分组(按10个时间戳分组)并计算精度。
我创建了指标,该指标继承了精度:
class PrecisionGrouped(tf.keras.metrics.Precision):
def __init__(self,
thresholds=None,
top_k=None,
class_id=None,
name=None,
dtype=None):
super(PrecisionGrouped, self).__init__(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.math.reduce_max(y_true, axis=1)
y_pred = tf.math.reduce_max(y_pred, axis=1)
return super().update_state(y_true, y_pred, sample_weight)
但是,当我运行代码时,它抱怨update_state方法应该返回张量。
但是我只是在调用父方法,该方法还会返回update_op
。
TypeError: To be compatible with tf.contrib.eager.defun, Python functions must return zero or more Tensors; in compilation of <function PrecisionGrouped.update_state at 0x1a384b1400>, found return value of type <class 'tensorflow.python.framework.ops.Operation'>, which is not a Tensor.
我要做的就是在输入中添加一个简单的预处理步骤。 tf.keras.metrics.Precision()
正常工作,但PrecisionGrouped()
无效。
答案 0 :(得分:2)
您可以从update_state()
方法中删除收益。
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.math.reduce_max(y_true, axis=1)
y_pred = tf.math.reduce_max(y_pred, axis=1)
super().update_state(y_true, y_pred, sample_weight)
您可以从自定义指标中删除返回声明和组操作。不需要。由于TPU存在问题,内置指标有不同的要求。一旦解决此问题,我们还将从内置指标中删除update_state的返回值。
有关更多详细信息,请参考此GitHub issue。