我正在尝试使用带有tensorflow后端的keras编写GAN生成器模型的代码。我希望生成器输出是精确坐标中图像值的向量(对于批处理中的每个图像,大小相同)。这些坐标也作为生成器的输入给出。
我尝试使用tf.gather_nd
作为函数来执行类似numpy的从精确坐标中提取值的操作。
img
是根据噪声图像生成的,形状为(=,?,28,28,1),
coordinates
是形状为(?,80,2)的输入张量,具有从生成的图像img
中提取的80个点,
vect
是一个输出向量,其大小应为(?,80),
在哪是批处理大小。
vect = Lambda(lambda x: tf.gather_nd(x, tf.cast(coordinates, 'int64')))(img)
最后,此函数的输出形状为(?,80,28,1)而不是(?,80)。
提取这些点怎么样?
答案 0 :(得分:0)
您可以使用tf.gather_nd
这样操作:
import tensorflow as tf
def extract_pixels(img, coords):
# Number of images and pixels
s = tf.shape(coords, out_type=coords.dtype)
n = s[0]
p = s[1]
# Make gather index
i = tf.range(n)
ii = tf.tile(i[:, tf.newaxis, tf.newaxis], [1, p, 1])
idx = tf.concat([ii, coords], axis=-1)
# Gather pixel values
pixels = tf.gather_nd(tf.squeeze(img, axis=-1), idx)
return pixels
# ...
vect = Lambda(lambda x: extract_pixels(x, tf.cast(coordinates, 'int64')))(img)