如何批量使用tf.gather?

时间:2019-04-09 16:28:56

标签: python-3.x tensorflow keras

我有一个A = 10x1000张量和一个B = 10x1000索引张量。张量B的值在0-999之间,用于从A收集值(B[0,:]A[0,:]收集,B[1,:]A[1,:]收集,等等。) / p>

但是,如果我使用tf.gather(A, B),当我期望返回(10, 1000, 1000)张量时,就会得到形状为10x1000的数组。有什么想法可以解决这个问题吗?

编辑

比方说A= [[1, 2, 3],[4,5,6]]B = [[0, 1, 1],[2,1,0]]我想要的是能够使用对应的B采样A。这应该得到C = [[1, 2, 2],[6,5,4]]

2 个答案:

答案 0 :(得分:0)

  1. 张量的尺寸是已知的。

首先,我们沿第一个维度“解栈”参数和索引(分别为AB)。然后我们应用tf.gather(),使A的行与B的行相对应。最后,我们将结果堆叠在一起。

import tensorflow as tf
import numpy as np

def custom_gather(a, b):
    unstacked_a = tf.unstack(a, axis=0)
    unstacked_b = tf.unstack(b, axis=0)
    gathered = [tf.gather(x, y) for x, y in zip(unstacked_a, unstacked_b)]
    return tf.stack(gathered, axis=0)

a = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6]]), tf.float32)
b = tf.convert_to_tensor(np.array([[0, 1, 1], [2, 1, 0]]), dtype=tf.int32)

gathered = custom_gather(a, b)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gathered))
# [[1. 2. 2.]
#  [6. 5. 4.]]

对于形状为1000x10的初始情况,我们得到:

a = tf.convert_to_tensor(np.random.normal(size=(10, 1000)), tf.float32)
b = tf.convert_to_tensor(np.random.randint(low=0, high=999, size=(10, 1000)), dtype=tf.int32)
gathered = custom_gather(a, b)
print(gathered.get_shape().as_list()) # [10, 1000]

更新

  1. 第一个维度是未知的(即None

仅当预先知道第一维时,先前的解决方案才有效。如果尺寸未知,我们可以按以下方法解决:

  • 我们将两个张量堆叠在一起,以使两个张量的行堆叠在一起:
# A = [[1, 2, 3], [4, 5, 6]]        [[[1 2 3]
#                            --->     [0 1 1]]
#                                    [[4 5 6]
# B = [[0, 1, 1], [2, 1, 0]]          [2 1 0]]]
  • 我们遍历此堆叠张量(由AB的堆叠行组成)的元素,然后使用tf.map_fn()函数应用tf.gather()

  • 我们堆叠使用tf.stack()

  • 获得的元素
import tensorflow as tf
import numpy as np

def custom_gather_v2(a, b):
    def apply_gather(x):
        return tf.gather(x[0], tf.cast(x[1], tf.int32))
    a = tf.cast(a, dtype=tf.float32)
    b = tf.cast(b, dtype=tf.float32)
    stacked = tf.stack([a, b], axis=1)
    gathered = tf.map_fn(apply_gather, stacked)
    return tf.stack(gathered, axis=0)

a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
b = np.array([[0, 1, 1], [2, 1, 0]], dtype=np.int32)

x = tf.placeholder(tf.float32, shape=(None, 3))
y = tf.placeholder(tf.int32, shape=(None, 3))

gathered = custom_gather_v2(x, y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gathered, feed_dict={x:a, y:b}))
# [[1. 2. 2.]
#  [6. 5. 4.]]

答案 1 :(得分:0)

if(line.compare("end") == 0 || line.compare("END") == 0) { break; } tf.gather一起使用:

batch_dims=-1