我有一个(?, 3, 2, 5)
形状的张量。我想提供一对索引,从该张量的第一维和第二维中进行选择,它们的形状为(3, 2)
。
如果我提供4对这样的对,我希望得到的形状为(?, 4, 5)
。我以为这就是batch_gather
的作用:在第一(批量)维上“广播”收集索引。但这不是它的作用:
import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))
indices = tf.constant([
[2, 1],
[2, 0],
[1, 1],
[0, 1]
], tf.int32)
tf.batch_gather(data, indices)
结果是<tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32>
,而不是我期望的形状。
如何在不显式索引批次(大小未知)的情况下做什么?
答案 0 :(得分:0)
使用tf.batch_gather
的{{1}}形状的前导尺寸应与tensor
张量的形状的前导尺寸相匹配。
indice
您宁愿使用import tensorflow as tf
data = tf.placeholder(tf.float32, (2, 3, 2, 5))
print(data.shape) // (2, 3, 2, 5)
# shape of indices, [2, 3]
indices = tf.constant([
[1, 1, 1],
[0, 0, 1]
])
print(tf.batch_gather(data, indices).shape) # (2, 3, 2, 5)
# if shape of indice was (2, 3, 1) the output would be 2, 3, 1, 5
,如下所示
tf.gather_nd
答案 1 :(得分:0)
我想避免transpose
和Python循环,我认为这可行。这是设置:
import numpy as np
import tensorflow as tf
shape = None, 3, 2, 5
data = tf.placeholder(tf.int32, shape)
idxs_list = [
[2, 1],
[2, 0],
[1, 1],
[0, 1]
]
idxs = tf.constant(idxs_list, tf.int32)
这使我们可以收集结果:
batch_size, num_idxs, num_channels = tf.shape(data)[0], tf.shape(idxs)[0], shape[-1]
batch_idxs = tf.math.floordiv(tf.range(0, batch_size * num_idxs), num_idxs)[:, None]
nd_idxs = tf.concat([batch_idxs, tf.tile(idxs, (batch_size, 1))], axis=1)
gathered = tf.reshape(tf.gather_nd(data, nd_idxs), (batch_size, num_idxs, num_channels))
当我们以4
的批处理量运行时,我们得到的结果形状为(4, 4, 5)
,即(batch_size, num_idxs, num_channels)
。
vals_shape = 4, *shape[1:]
vals = np.arange(int(np.prod(vals_shape))).reshape(vals_shape)
with tf.Session() as sess:
result = gathered.eval(feed_dict={data: vals})
与numpy
索引相关联:
x, y = zip(*idxs_list)
assert np.array_equal(result, vals[:, x, y])
本质上,gather_nd
希望在第一维中使用批处理索引,并且必须为每个索引对重复一次(即,[0, 0, 0, 0, 1, 1, 1, 1, 2, ...]
如果有4个索引对)。
由于似乎没有tf.repeat
,所以我使用了range
和floordiv
,然后concat
用所需的(x,y )索引(本身batch_size
次平铺)。