tf.gather_nd的用法

时间:2019-09-29 21:54:24

标签: tensorflow

假设您有一个3张量

data = np.reshape(np.arange(12), [2, 2, 3])
x = tf.constant(data)

将其视为由上一个索引索引的2x2矩阵,我想从第一矩阵中获取第一列,从第二矩阵中获取第二列,从第三矩阵中获取第二列。

如何使用tf.gather_nd做到这一点?

2 个答案:

答案 0 :(得分:0)

您首先需要生成所需的索引。

snippetManager.register([
    {
        "tabTrigger": "rett",
        "name": "rett",
        "content": "return true;"
    },
    {
        "tabTrigger": "retf",
        "name": "retf",
        "content": "return false;"
    },
    {
        "tabTrigger": "test_snippet",
        "name": "test_snippet",
        "content": "echo \"This is a test snippet\";\";"
    }
], "php")

答案 1 :(得分:0)

我在网上找到了以下教程,其中介绍了如何处理此类问题:https://geekyisawesome.blogspot.com/2018/05/fancy-indexing-in-tensorflow-getting.html

假设我们有一个4x3的矩阵

M = tf.constant(np.arange(12).reshape(4,3))

现在让我们说,您想要第一行的第三个元素,第二行的第二个元素,第三行的第一个元素和第四行的第二个元素。如本教程中所述,可以通过以下方式完成:

idx = tf.constant([2,1,0,1], tf.int32)
x = tf.gather_nd(M, tf.stack([tf.range(M.shape[0]), idx], axis=1))

但是如果M的行数未知,该怎么办? (并且idx为适当大小的整数张量)然后tf.range(M.shape [0])将引发错误。我该如何解决?