我有这种损失功能:
def bingo_loss(y_true, y_pred):
one = tf.ones([10])
y2 = tf.subtract(one, y_pred)
_, indices = tf.nn.top_k(y_pred, k = 3)
loss = tf.scatter_update(y_pred, indices, y2)
return loss
输出有10个节点。
损失选择前3个最大输出,其索引存储在indices
。
对于前3个值,损失为(1 - y_pred)。
否则,损失是(y_pred)。
这是通过scatter_update
实现的。
但我有一个模糊的错误:
File "Keras-NN-training2.py", line 104, in bingo_loss
loss = tf.scatter_update(y_pred, indices, y2)
File ".....state_ops.py", line 352, in scatter_update
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
AttributeError: 'Tensor' object has no attribute 'handle'
我的朋友在另一台机器上跑步说他得到了
TypeError: 'ScatterUpdate' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)
我已经放弃了这种方法 - 我宁愿编写一个新的层来实现我所需要的,而不仅仅是非常严格的损失函数。但知道解决这个问题的方法真好。谢谢!