实现下面代码的任何简单方法, 特别是处理未知维度,我想将此代码添加到损失函数中。谢谢。
result =[]
for i in range(0,x.shape[0]):
tmp2 = tf.gather_nd(x[i], y[i])
result.append(tmp2)
finalResult = tf.stack(result)
示例
x shape =(?,3,2)
y shape =(?,1)
x :
[[[ 0 1]
[ 2 3]
[ 4 5]]
[[ 6 7]
[ 8 9]
[10 11]]
[[12 13]
[14 15]
[16 17]]...]
y :
[[1]
[0]
[2]...]
finalResult :
[[ 2 3]
[ 6 7]
[16 17]...]
答案 0 :(得分:0)
jdehesa的回复很有帮助。非常感谢。 必须添加第一个维度的索引进行查询。 (顺便说一句,我在损失函数中犯了一个错误。它必须是可区分的。 但这是另一个问题。)无论如何,再次感谢。