我有Tensorflow情况。我想使用Tensorflow操作获得某种类型的输出。
假设我们有一个张量变量:
"input" = (2, 4, 4)
[[[ 0.6036284 0.0281072 0.78739774 0.79748493]
[ 0.92121416 0.31211454 0.75201935 0.49418229]
[ 0.99500716 0.35610485 0.78246456 0.32932794]
[ 0.44941011 0.33340591 0.56897491 0.16929366]]
[[ 0.82108098 0.50557786 0.76569009 0.04855939]
[ 0.55340368 0.11384677 0.63739866 0.09481387]
[ 0.52711403 0.5621863 0.44211769 0.85780412]
[ 0.15423198 0.80663997 0.86868405 0.48221472]]]
我们有另一个张量变量,它包含要从a。
中提取的元素的索引"idx" = (2, 2)
[[2 0]
[2 0]]
我想形成一个任务,当我们切片"输入"使用" idx"然后得到的结果如下。请注意,切片基于第二维进行。
Output: (2, 2, 4)
[[[ 0.99500716 0.35610485 0.78246456 0.32932794]
[ 0.6036284 0.0281072 0.78739774 0.79748493]]
[[ 0.52711403 0.5621863 0.44211769 0.85780412]
[ 0.82108098 0.50557786 0.76569009 0.04855939]]]
我想使用Tensorflow实现类似的操作,其中"输入"和" idx"动态填充。
当我们明确提及" idx"时,我能想到的一种方式;是:
idx = [[[0,2],[0,0]], [[1,2],[1,0]]]
output = tf.gather_nd(input, idx)
但我不确定如何从动态填充的idx = [[2]构造idx = [[[0,2],[0,0]],[[1,2],[1,0]]] 0],[2 0]]
我尝试使用不同的组合使用tf.map_fn,但我仍然无法找到解决方案。
任何帮助将不胜感激..谢谢
答案 0 :(得分:1)
您可以通过以下方式构建完整索引:
#Use meshgrid to get [[0 0] [1 1]]
mesh = tf.meshgrid(tf.range(indices.shape[1]), tf.range(indices.shape[0]))[1]
#Stack mesh and the idx
full_indices = tf.stack([mesh, indices], axis=2)
#Output
# [[[0 2] [0 0]]
# [[1 2] [1 0]]]