
Asked: 2019-01-29 16:52:17

Tags: python tensorflow

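I have a tensor of shape (?, 3, 2, 5). I want to supply pairs of indices that select from the first and second dimensions after the batch dimension, which have shape (3, 2).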

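If I supply 4 such pairs, I expect the result to have shape (?, 4, 5). I thought this was what batch_gather is for: to "broadcast" the gather indices over the first (batch) dimension. But that is not what it does: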

import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))

indices = tf.constant([
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
], tf.int32)

tf.batch_gather(data, indices)

The result is <tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32>, which is not the shape I expected.


2 Answers:

Answer 0 (score: 0)



You would rather use tf.batch_gather as follows:

import tensorflow as tf
data = tf.placeholder(tf.float32, (2, 3, 2, 5))
print(data.shape)  # (2, 3, 2, 5)

# shape of indices: (2, 3)
indices = tf.constant([
    [1, 1, 1],
    [0, 0, 1]
])
print(tf.batch_gather(data, indices).shape)  # (2, 3, 2, 5)
# if the shape of indices were (2, 3, 1), the output shape would be (2, 3, 1, 5)
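The shapes quoted in this thread are consistent with one rule: the output shape is the shape of indices followed by whatever dimensions of params lie past the rank of indices. Below is a minimal pure-Python sketch of that rule. The helper name batch_gather_output_shape is hypothetical (it is not a TensorFlow function), and None stands in for an unknown dimension such as `?`.

```python
def batch_gather_output_shape(params_shape, indices_shape):
    """Shape that batch_gather(params, indices) would report.

    Hypothetical helper for illustration only. It assumes the rule
    output shape == indices shape + params shape past the rank of indices,
    which matches the shapes quoted in this thread.
    """
    rank = len(indices_shape)
    # Every indices dimension except the last is a batch dimension and
    # has to agree with the matching params dimension (None = unknown).
    for p, i in zip(params_shape[:rank - 1], indices_shape[:rank - 1]):
        if p is not None and i is not None and p != i:
            raise ValueError("batch dimensions do not match: %r vs %r" % (p, i))
    return tuple(indices_shape) + tuple(params_shape[rank:])


# The question's case: indices of shape (4, 2) against data of shape (?, 3, 2, 5).
question_shape = batch_gather_output_shape((None, 3, 2, 5), (4, 2))
```

Read this way, the (4, 2) index tensor in the question is taken as 4 batch rows with 2 indices each into the size-3 axis, not as 4 (x, y) pairs, which is why the result is not (?, 4, 5).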


Answer 1 (score: 0)


import numpy as np
import tensorflow as tf

shape = None, 3, 2, 5
data = tf.placeholder(tf.int32, shape)
idxs_list = [
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
]
idxs = tf.constant(idxs_list, tf.int32)


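# batch_size and num_idxs are read at run time; num_channels is known statically.
batch_size, num_idxs, num_channels = tf.shape(data)[0], tf.shape(idxs)[0], shape[-1]

# Batch index for every gathered row, as a column: [0, 0, 0, 0, 1, 1, 1, 1, ...]
batch_idxs = tf.math.floordiv(tf.range(0, batch_size * num_idxs), num_idxs)[:, None]
# Prepend the batch index to the (x, y) pairs, which are tiled once per batch element.
nd_idxs = tf.concat([batch_idxs, tf.tile(idxs, (batch_size, 1))], axis=1)

# gather_nd returns (batch_size * num_idxs, num_channels); restore the batch dimension.
gathered = tf.reshape(tf.gather_nd(data, nd_idxs), (batch_size, num_idxs, num_channels))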

When we run this with a batch size of 4, we get a result of shape (4, 4, 5), i.e. (batch_size, num_idxs, num_channels):

vals_shape = 4, *shape[1:]
vals = np.arange(int(np.prod(vals_shape))).reshape(vals_shape)

with tf.Session() as sess:
    result = gathered.eval(feed_dict={data: vals})


x, y = zip(*idxs_list)
assert np.array_equal(result, vals[:, x, y])

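Essentially, gather_nd expects the batch index in the first dimension, and that index has to be repeated once for every index pair (i.e. [0, 0, 0, 0, 1, 1, 1, 1, 2, ...] if there are 4 index pairs).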

Since there doesn't seem to be a tf.repeat, I used range and floordiv, and then concat with the desired (x, y) indices (themselves tiled batch_size times).
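To make the index construction easier to follow without a TensorFlow session, here is a rough pure-Python sketch of the same idea on nested lists. The names build_nd_indices and gather_nd_lists are hypothetical helpers written for this illustration; they only mimic what range, floordiv, tile, concat and gather_nd do in the answer.

```python
def build_nd_indices(batch_size, pairs):
    """Return [b, x, y] triples in the same order as nd_idxs in the answer."""
    num_pairs = len(pairs)
    # range + floor division repeats each batch index num_pairs times:
    # 0, 0, 0, 0, 1, 1, 1, 1, ... for 4 pairs.
    batch_idxs = [i // num_pairs for i in range(batch_size * num_pairs)]
    # Tiling the pairs batch_size times lines every pair up with every batch index.
    tiled = list(pairs) * batch_size
    return [[b, x, y] for b, (x, y) in zip(batch_idxs, tiled)]


def gather_nd_lists(data, nd_indices):
    """Mimic gather_nd on nested lists: each [b, x, y] selects data[b][x][y]."""
    return [data[b][x][y] for b, x, y in nd_indices]


pairs = [[2, 1], [2, 0], [1, 1], [0, 1]]
batch_size, channels = 2, 5

# Nested-list stand-in for a (2, 3, 2, 5) tensor filled with 0..59.
data = [[[[((b * 3 + x) * 2 + y) * channels + c for c in range(channels)]
          for y in range(2)]
         for x in range(3)]
        for b in range(batch_size)]

nd = build_nd_indices(batch_size, pairs)
flat = gather_nd_lists(data, nd)

# Regroup the flat rows into (batch_size, num_pairs, channels), like the final reshape.
gathered = [flat[b * len(pairs):(b + 1) * len(pairs)] for b in range(batch_size)]
```

Each row of gathered holds, for one batch element, the 5-channel vectors at the requested (x, y) positions, which is the list equivalent of vals[:, x, y].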