如何在张量流中做多维切片?

时间:2016-12-06 01:46:31

标签: tensorflow

例如:

array = [[1, 2, 3], [4, 5, 6]]

slice = [[0, 0, 1], [0, 1, 2]]

output = [[1, 1, 2], [4, 5,6]]

我已经尝试了array[slice],但那并没有奏效。我也无法让tf.gathertf.gather_nd工作,尽管这些最初似乎是正确使用的功能。请注意,这些都是图形中的张量。

如何根据切片在数组中选择这些值?

2 个答案:

答案 0 :(得分:1)

您需要为slice张量添加维度tf.pack,然后我们就可以使用tf.gather_nd毫无问题。

import tensorflow as tf

tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
old_slice = tf.constant([[0, 0, 1], [0, 1, 2]])

# We need to add a dimension - we need a tensor of rank 2, 3, 2 instead of 2, 3
dims = tf.constant([[0, 0, 0], [1, 1, 1]])
new_slice = tf.pack([dims, old_slice], 2)
out = tf.gather_nd(tensor, new_slice)

如果我们运行以下代码:

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  run_tensor, run_slice, run_out = sess.run([tensor, new_slice, out])
  print 'Input tensor:'
  print run_tensor
  print 'Correct param for gather_nd:'
  print run_slice
  print 'Output:'
  print run_out

这应该给出正确的输出:

Input tensor:
[[1 2 3]
 [4 5 6]]
Correct param for gather_nd:
[[[0 0]
  [0 0]
  [0 1]]

 [[1 0]
  [1 1]
  [1 2]]]
Output:
[[1 1 2]
 [4 5 6]]

答案 1 :(得分:0)

一种更简单的计算结果的方法,也是更一般的性质,是直接利用 tf.gatherbatch_dims 参数:

>>> array = tf.constant([[1,2,3], [4,5,6]])
>>> slice = tf.constant([[0,0,1], [0,1,2]])
>>> output = tf.constant([[1,1,2], [4,5,6]])
>>> tf.gather(array, slice, batch_dims=1, axis=1)
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 1, 2],
       [4, 5, 6]], dtype=int32)>