tensorflow tf.gather_nd用于处理未知维度

时间:2018-05-30 15:17:59

标签: tensorflow

实现下面代码的任何简单方法, 特别是处理未知维度,我想将此代码添加到损失函数中。谢谢。

result =[]
for i in range(0,x.shape[0]):
    tmp2 = tf.gather_nd(x[i], y[i])
    result.append(tmp2)
finalResult = tf.stack(result)

示例

x shape =(?,3,2)

y shape =(?,1)

x :
[[[ 0  1]
  [ 2  3]
  [ 4  5]]

 [[ 6  7]
  [ 8  9]
  [10 11]]

 [[12 13]
  [14 15]
  [16 17]]...]

y :
[[1]
 [0]
 [2]...]

finalResult :
[[ 2  3]
 [ 6  7]
 [16 17]...]

1 个答案:

答案 0 :(得分:0)

jdehesa的回复很有帮助。非常感谢。 必须添加第一个维度的索引进行查询。 (顺便说一句,我在损失函数中犯了一个错误。它必须是可区分的。 但这是另一个问题。)无论如何,再次感谢。