如何在张量流中提取图像的精确坐标矢量?

时间:2019-04-29 18:25:16

标签: python tensorflow keras

我正在尝试使用带有tensorflow后端的keras编写GAN生成器模型的代码。我希望生成器输出是精确坐标中图像值的向量(对于批处理中的每个图像,大小相同)。这些坐标也作为生成器的输入给出。

我尝试使用tf.gather_nd作为函数来执行类似numpy的从精确坐标中提取值的操作。

img是根据噪声图像生成的,形状为(=,?,28,28,1),

coordinates是形状为(?,80,2)的输入张量,具有从生成的图像img中提取的80个点,

vect是一个输出向量,其大小应为(?,80), 在哪是批处理大小。

vect = Lambda(lambda x: tf.gather_nd(x, tf.cast(coordinates, 'int64')))(img)

最后,此函数的输出形状为(?,80,28,1)而不是(?,80)。

提取这些点怎么样?

1 个答案:

答案 0 :(得分:0)

您可以使用tf.gather_nd这样操作:

import tensorflow as tf

def extract_pixels(img, coords):
    # Number of images and pixels
    s = tf.shape(coords, out_type=coords.dtype)
    n = s[0]
    p = s[1]
    # Make gather index
    i = tf.range(n)
    ii = tf.tile(i[:, tf.newaxis, tf.newaxis], [1, p, 1])
    idx = tf.concat([ii, coords], axis=-1)
    # Gather pixel values
    pixels = tf.gather_nd(tf.squeeze(img, axis=-1), idx)
    return pixels

# ...
vect = Lambda(lambda x: extract_pixels(x, tf.cast(coordinates, 'int64')))(img)