tensorflow.gather()到底是做什么的?

时间:2019-06-28 07:41:59

标签: numpy tensorflow deep-learning

我看到了包含函数tf.gather()的三元组丢失代码。该功能做什么?

我已经通过tensorflow的官方网站进行了定义,但是仍然无法获得它。

def margin_triplet_loss(y_true, y_pred, margin, batch_size):
    anchor = tf.gather(y_pred, tf.range(0, batch_size, 3))
    positive = tf.gather(y_pred, tf.range(1, batch_size, 3))
    negative = tf.gather(y_pred, tf.range(2, batch_size, 3))

    loss = K.maximum(margin
                 + K.sum(K.square(anchor-positive), axis=1)
                 - K.sum(K.square(anchor-negative), axis=1),
                 0.0)
    return K.mean(loss)

1 个答案:

答案 0 :(得分:1)

tf.gather是对数组建立索引的函数。您收集由index参数指定的元素。对于张量流张量,这本来是不可能的。

tf.gather(y_pred,tf.range(0,batch_size,3))在numpy中等效于y_pred [0:batch_size:3],这意味着您从第一个元素开始返回每个第三个元素。