此问题与this one相同,但适用于Tensorflow。
假设我有“行”的2D张量,并想从每一行中选择第i个元素并组成这些元素的结果列,在选择器张量中有i-s,如下所示
import tensorflow as tf
import numpy as np
rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
# gives
# array([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14],
# [15, 16, 17],
# [18, 19, 20],
# [21, 22, 23],
# [24, 25, 26],
# [27, 28, 29]])
selector = tf.get_variable("selector", [10,1], dtype=tf.int8, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]]))
result_of_selection = ...
# should be
# array([[ 0],
# [ 4],
# [ 6],
# [11],
# [13],
# [15],
# [18],
# [23],
# [26],
# [28]])
我该怎么做?
更新
我这样写(感谢@Psidom)
import tensorflow as tf
import numpy as np
rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
# selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([0, 1, 0, 2, 1, 0, 0, 2, 2, 1], dtype=tf.int32))
# selector = tf.expand_dims(selector, axis=1)
selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]], dtype=tf.int32))
ordinals = tf.reshape(tf.range(rows.shape[0]), (-1,1))
#idx = tf.concat([selector, ordinals], axis=1)
idx = tf.stack([selector, ordinals], axis=-1)
result = tf.gather_nd(rows, idx)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
rows_value, result_value = sess.run([rows, result])
print("rows_value: " + str(rows_value))
print("selector_value: " + str(result_value))
它给了
rows_value: [[ 0. 1. 2.]
[ 3. 4. 5.]
[ 6. 7. 8.]
[ 9. 10. 11.]
[ 12. 13. 14.]
[ 15. 16. 17.]
[ 18. 19. 20.]
[ 21. 22. 23.]
[ 24. 25. 26.]
[ 27. 28. 29.]]
selector_value: [[ 0.]
[ 4.]
[ 2.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]]
即。不正确。
更新2
固定行
idx = tf.stack([ordinals, selector], axis=-1)
订单不正确。
答案 0 :(得分:2)
执行此操作的一种方法是通过堆叠可以使用tf.range
使用selector
创建的行索引来显式构建索引,然后使用tf.gather_nd
来收集项目:< / p>
rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
selector = tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]])
idx = tf.stack([tf.reshape(tf.range(rows.shape[0]), (-1,1)), selector], axis=-1)
with tf.Session() as sess:
print(sess.run(tf.gather_nd(rows, idx)))
#[[ 0.]
# [ 4.]
# [ 6.]
# [ 11.]
# [ 13.]
# [ 15.]
# [ 18.]
# [ 23.]
# [ 26.]
# [ 28.]]
这里idx
是原始张量中所有元素的实际索引:
with tf.Session() as sess:
print(idx.eval())
#[[[0 0]]
# [[1 1]]
# [[2 0]]
# [[3 2]]
# [[4 1]]
# [[5 0]]
# [[6 0]]
# [[7 2]]
# [[8 2]]
# [[9 1]]]
编辑:selector
作为变量:
selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]]))
idx = tf.stack([tf.reshape(tf.range(rows.shape[0]), (-1,1)), selector], axis=-1)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(tf.gather_nd(rows, idx)))
#[[ 0.]
# [ 4.]
# [ 6.]
# [ 11.]
# [ 13.]
# [ 15.]
# [ 18.]
# [ 23.]
# [ 26.]
# [ 28.]]