如何动态切片张量变量

时间:2018-05-25 07:14:36

标签: python tensorflow

我有Tensorflow情况。我想使用Tensorflow操作获得某种类型的输出。

假设我们有一个张量变量:

"input" = (2, 4, 4)
[[[ 0.6036284   0.0281072   0.78739774  0.79748493]
  [ 0.92121416  0.31211454  0.75201935  0.49418229]
  [ 0.99500716  0.35610485  0.78246456  0.32932794]
  [ 0.44941011  0.33340591  0.56897491  0.16929366]]

 [[ 0.82108098  0.50557786  0.76569009  0.04855939]
  [ 0.55340368  0.11384677  0.63739866  0.09481387]
  [ 0.52711403  0.5621863   0.44211769  0.85780412]
  [ 0.15423198  0.80663997  0.86868405  0.48221472]]]

我们有另一个张量变量,它包含要从a。

中提取的元素的索引
"idx" = (2, 2)
[[2 0]
 [2 0]]

我想形成一个任务,当我们切片"输入"使用" idx"然后得到的结果如下。请注意,切片基于第二维进行。

Output: (2, 2, 4)
 [[[ 0.99500716  0.35610485  0.78246456  0.32932794]
  [ 0.6036284   0.0281072   0.78739774  0.79748493]]

 [[ 0.52711403  0.5621863   0.44211769  0.85780412]
  [ 0.82108098  0.50557786  0.76569009  0.04855939]]]

我想使用Tensorflow实现类似的操作,其中"输入"和" idx"动态填充。

当我们明确提及" idx"时,我能想到的一种方式;是:

idx = [[[0,2],[0,0]], [[1,2],[1,0]]]
output = tf.gather_nd(input, idx)

但我不确定如何从动态填充的idx = [[2]构造idx = [[[0,2],[0,0]],[[1,2],[1,0]]] 0],[2 0]]

我尝试使用不同的组合使用tf.map_fn,但我仍然无法找到解决方案。

任何帮助将不胜感激..谢谢

1 个答案:

答案 0 :(得分:1)

您可以通过以下方式构建完整索引:

#Use meshgrid to get [[0 0] [1 1]]
mesh = tf.meshgrid(tf.range(indices.shape[1]), tf.range(indices.shape[0]))[1]

#Stack mesh and the idx
full_indices = tf.stack([mesh, indices], axis=2)
#Output
# [[[0 2] [0 0]]
#  [[1 2] [1 0]]]