如何从张量流中的张量中的每一行中选择不同的列?

时间:2017-09-28 13:01:12

标签: python tensorflow

此问题与this one相同,但适用于Tensorflow。

假设我有“行”的2D张量,并想从每一行中选择第i个元素并组成这些元素的结果列,在选择器张量中有i-s,如下所示

import tensorflow as tf
import numpy as np

rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
# gives
# array([[ 0,  1,  2],
#        [ 3,  4,  5],
#        [ 6,  7,  8],
#        [ 9, 10, 11],
#        [12, 13, 14],
#        [15, 16, 17],
#        [18, 19, 20],
#        [21, 22, 23],
#        [24, 25, 26],
#        [27, 28, 29]])


selector = tf.get_variable("selector", [10,1], dtype=tf.int8, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]]))

result_of_selection = ...

# should be
# array([[ 0],
#        [ 4],
#        [ 6],
#        [11],
#        [13],
#        [15],
#        [18],
#        [23],
#        [26],
#        [28]])

我该怎么做?

更新

我这样写(感谢@Psidom)

import tensorflow as tf
import numpy as np

rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)

# selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([0, 1, 0, 2, 1, 0, 0, 2, 2, 1], dtype=tf.int32))
# selector = tf.expand_dims(selector, axis=1)
selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]], dtype=tf.int32))

ordinals = tf.reshape(tf.range(rows.shape[0]), (-1,1))

#idx = tf.concat([selector, ordinals], axis=1)
idx = tf.stack([selector, ordinals], axis=-1)

result = tf.gather_nd(rows, idx)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    rows_value, result_value = sess.run([rows, result])
    print("rows_value: " + str(rows_value))
    print("selector_value: " + str(result_value))

它给了

rows_value: [[  0.   1.   2.]
 [  3.   4.   5.]
 [  6.   7.   8.]
 [  9.  10.  11.]
 [ 12.  13.  14.]
 [ 15.  16.  17.]
 [ 18.  19.  20.]
 [ 21.  22.  23.]
 [ 24.  25.  26.]
 [ 27.  28.  29.]]
selector_value: [[ 0.]
 [ 4.]
 [ 2.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]]

即。不正确。

更新2

固定行

idx = tf.stack([ordinals, selector], axis=-1)

订单不正确。

1 个答案:

答案 0 :(得分:2)

执行此操作的一种方法是通过堆叠可以使用tf.range使用selector创建的行索引来显式构建索引,然后使用tf.gather_nd来收集项目:< / p>

rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
selector = tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]])

idx = tf.stack([tf.reshape(tf.range(rows.shape[0]), (-1,1)), selector], axis=-1)

with tf.Session() as sess:
    print(sess.run(tf.gather_nd(rows, idx)))

#[[  0.]
# [  4.]
# [  6.]
# [ 11.]
# [ 13.]
# [ 15.]
# [ 18.]
# [ 23.]
# [ 26.]
# [ 28.]]

这里idx是原始张量中所有元素的实际索引:

with tf.Session() as sess:
    print(idx.eval())
#[[[0 0]]

# [[1 1]]

# [[2 0]]

# [[3 2]]

# [[4 1]]

# [[5 0]]

# [[6 0]]

# [[7 2]]

# [[8 2]]

# [[9 1]]]

编辑selector作为变量:

selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]]))
idx = tf.stack([tf.reshape(tf.range(rows.shape[0]), (-1,1)), selector], axis=-1)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(tf.gather_nd(rows, idx)))

#[[  0.]
# [  4.]
# [  6.]
# [ 11.]
# [ 13.]
# [ 15.]
# [ 18.]
# [ 23.]
# [ 26.]
# [ 28.]]