假设您有一个3张量
data = np.reshape(np.arange(12), [2, 2, 3])
x = tf.constant(data)
将其视为由上一个索引索引的2x2矩阵,我想从第一矩阵中获取第一列,从第二矩阵中获取第二列,从第三矩阵中获取第二列。
如何使用tf.gather_nd做到这一点?
答案 0 :(得分:0)
您首先需要生成所需的索引。
snippetManager.register([
{
"tabTrigger": "rett",
"name": "rett",
"content": "return true;"
},
{
"tabTrigger": "retf",
"name": "retf",
"content": "return false;"
},
{
"tabTrigger": "test_snippet",
"name": "test_snippet",
"content": "echo \"This is a test snippet\";\";"
}
], "php")
答案 1 :(得分:0)
我在网上找到了以下教程,其中介绍了如何处理此类问题:https://geekyisawesome.blogspot.com/2018/05/fancy-indexing-in-tensorflow-getting.html
假设我们有一个4x3的矩阵
M = tf.constant(np.arange(12).reshape(4,3))
现在让我们说,您想要第一行的第三个元素,第二行的第二个元素,第三行的第一个元素和第四行的第二个元素。如本教程中所述,可以通过以下方式完成:
idx = tf.constant([2,1,0,1], tf.int32)
x = tf.gather_nd(M, tf.stack([tf.range(M.shape[0]), idx], axis=1))
但是如果M的行数未知,该怎么办? (并且idx为适当大小的整数张量)然后tf.range(M.shape [0])将引发错误。我该如何解决?