我有一个形状为params
的张量(?, 70, 64)
和另一个形状为indices
的张量(?, 1)
。我想使用第二张量索引到第一个张量的轴1,以得到形状为(?, 64)
的结果。
我不知道如何去做。这是我尝试过的:
tf.gather(params, indices) # returns a tensor of shape (?, 1, 70, 64)
tf.gather(params, indices, axis=1) # returns a tensor of shape (?, ?, 1, 64)
tf.gather_nd(params, indices) # returns a tensor of shape (?, 70, 64)
(我使用的是TensorFlow的旧版本,没有batch_gather
。)
任何帮助将不胜感激。
谢谢!
答案 0 :(得分:0)
您可以使用tf.stack
将索引转换为形状(?, 2)
的张量,第二维中的第一个数字为批号。然后,如果我正确理解了您的目标,那么将这些新索引与tf.gather_nd
配合使用就可以为您提供所需的信息。
由于您的indices
是形状为(?, 1)
的张量,所以batch_gather
会给您(?, 1, 64)
,这意味着从形状{{1} }。以下代码显示了两种方法给您相同的结果:
(?, 64)
总体而言,最佳解决方案取决于特定的用例,并且将import numpy as np
import tensorflow as tf
params = tf.constant(np.arange(3*70*64).reshape(3, 70, 64))
init_indices = tf.constant([[2], [1], [0]])
indices = tf.stack(
[tf.range(init_indices.shape[0]), tf.reshape(init_indices, [-1])],
axis=1
)
output = tf.gather_nd(params, indices)
batch_gather = tf.reshape(tf.batch_gather(params, init_indices),
[params.shape[0], -1])
with tf.Session() as sess:
print('tf.gather_nd')
print(output.shape)
print(sess.run(output))
print('batch_gather')
print(batch_gather.shape)
print(sess.run(batch_gather))
与tf.gather_nd
一起使用,关键是要获取批处理大小,即第一维。一种可能并非最佳的方法是使用tf.stack
:
tf.shape
要指出的一件事是因为批量大小未知,import numpy as np
import tensorflow as tf
params = tf.placeholder(shape=(None, 70, 64), dtype=tf.int32)
init_indices = tf.placeholder(shape=(None, 1), dtype=tf.int32)
indices = tf.stack(
[tf.range(tf.shape(init_indices)[0]), tf.reshape(init_indices, [-1])],
axis=1
)
output = tf.gather_nd(params, indices)
batch_gather = tf.reshape(tf.batch_gather(params, init_indices),
[tf.shape(params)[0], -1])
with tf.Session() as sess:
print('tf.gather_nd')
print(output.shape)
print(sess.run(
output, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64),
init_indices: [[2], [1], [0]]}
))
print('batch_gather')
print(batch_gather.shape)
print(sess.run(
batch_gather, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64),
init_indices: [[2], [1], [0]]}
))
给出的是print(batch_gather.shape)
而不是(?, ?)
。