Tensorflow:如何使用尺寸未知的另一个张量索引一个张量?

时间:2018-12-14 05:45:48

标签: python tensorflow

我有一个形状为params的张量(?, 70, 64)和另一个形状为indices的张量(?, 1)。我想使用第二张量索引到第一个张量的轴1,以得到形状为(?, 64)的结果。

我不知道如何去做。这是我尝试过的:

tf.gather(params, indices)           # returns a tensor of shape (?, 1, 70, 64)
tf.gather(params, indices, axis=1)   # returns a tensor of shape (?, ?, 1, 64)
tf.gather_nd(params, indices)        # returns a tensor of shape (?, 70, 64)

(我使用的是TensorFlow的旧版本,没有batch_gather。) 任何帮助将不胜感激。

谢谢!

1 个答案:

答案 0 :(得分:0)

您可以使用tf.stack将索引转换为形状(?, 2)的张量,第二维中的第一个数字为批号。然后,如果我正确理解了您的目标,那么将这些新索引与tf.gather_nd配合使用就可以为您提供所需的信息。

由于您的indices是形状为(?, 1)的张量,所以batch_gather会给您(?, 1, 64),这意味着从形状{{1} }。以下代码显示了两种方法给您相同的结果:

(?, 64)

根据评论“未知的第一维度”进行编辑

总体而言,最佳解决方案取决于特定的用例,并且将import numpy as np import tensorflow as tf params = tf.constant(np.arange(3*70*64).reshape(3, 70, 64)) init_indices = tf.constant([[2], [1], [0]]) indices = tf.stack( [tf.range(init_indices.shape[0]), tf.reshape(init_indices, [-1])], axis=1 ) output = tf.gather_nd(params, indices) batch_gather = tf.reshape(tf.batch_gather(params, init_indices), [params.shape[0], -1]) with tf.Session() as sess: print('tf.gather_nd') print(output.shape) print(sess.run(output)) print('batch_gather') print(batch_gather.shape) print(sess.run(batch_gather)) tf.gather_nd一起使用,关键是要获取批处理大小,即第一维。一种可能并非最佳的方法是使用tf.stack

tf.shape

要指出的一件事是因为批量大小未知,import numpy as np import tensorflow as tf params = tf.placeholder(shape=(None, 70, 64), dtype=tf.int32) init_indices = tf.placeholder(shape=(None, 1), dtype=tf.int32) indices = tf.stack( [tf.range(tf.shape(init_indices)[0]), tf.reshape(init_indices, [-1])], axis=1 ) output = tf.gather_nd(params, indices) batch_gather = tf.reshape(tf.batch_gather(params, init_indices), [tf.shape(params)[0], -1]) with tf.Session() as sess: print('tf.gather_nd') print(output.shape) print(sess.run( output, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64), init_indices: [[2], [1], [0]]} )) print('batch_gather') print(batch_gather.shape) print(sess.run( batch_gather, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64), init_indices: [[2], [1], [0]]} )) 给出的是print(batch_gather.shape)而不是(?, ?)