我看到了包含函数tf.gather()的三元组丢失代码。该功能做什么?
我已经通过tensorflow的官方网站进行了定义,但是仍然无法获得它。
def margin_triplet_loss(y_true, y_pred, margin, batch_size):
anchor = tf.gather(y_pred, tf.range(0, batch_size, 3))
positive = tf.gather(y_pred, tf.range(1, batch_size, 3))
negative = tf.gather(y_pred, tf.range(2, batch_size, 3))
loss = K.maximum(margin
+ K.sum(K.square(anchor-positive), axis=1)
- K.sum(K.square(anchor-negative), axis=1),
0.0)
return K.mean(loss)
答案 0 :(得分:1)
tf.gather是对数组建立索引的函数。您收集由index参数指定的元素。对于张量流张量,这本来是不可能的。
tf.gather(y_pred,tf.range(0,batch_size,3))在numpy中等效于y_pred [0:batch_size:3],这意味着您从第一个元素开始返回每个第三个元素。