批量4D张量Tensorflow索引

时间:2017-01-06 12:37:16

标签: python tensorflow

鉴于

  • batch_images:4D张量形状(B, H, W, C)
  • x:3D张量形状(B, H, W)
  • y:3D张量形状(B, H, W)

目标

如何使用batch_imagesx坐标索引y以获得形状B, H, W, C的4D张量。也就是说,我希望获得每个批次,并为每对(x, y)获得一个形状C的张量。

在numpy中,这可以使用input_img[np.arange(B)[:,None,None], y, x]来实现,但我似乎无法使其在tensorflow中工作。

到目前为止我的尝试

def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for 
    coordinate vectors x and y from a  4D tensor image.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    C = tf.shape(img)[3]

    # flatten image
    img_flat = tf.reshape(img, [-1, C])

    # flatten idx
    idx_flat = (x*W) + y

    return tf.gather(img_flat, idx_flat)

返回不正确的形状张量(B, H, W)

1 个答案:

答案 0 :(得分:1)

应该可以通过展平张量来实现,但是在索引计算中必须考虑批量维度。 为此,您必须制作一个额外的虚拟批量索引张量,其形状与# Serve php scripts. - url: /(.+\.php)$ script: \1 x相同,始终包含当前批次的索引。 这基本上是你的numpy示例中的y,你的TensorFlow代码中缺少它。

您还可以使用tf.gather_nd来简化一些事情,它会为您进行索引计算。

以下是一个例子:

np.arange(B)