TensorFlow,批量索引(第一维)和排序

时间:2018-05-30 12:45:04

标签: python sorting tensorflow indexing keras

我有一个形状为(?,368,5)的参数张量,以及一个形状为(?,368)的查询张量。查询张量存储用于对第一张量进行排序的索引。

所需的输出形状为:(?,368,5)。由于我需要它用于神经网络中的损失函数,因此使用的操作应该保持可微。此外,在运行时,第一个轴?的大小对应于批量大小。

到目前为止,我尝试了tf.gathertf.gather_nd tf.gather(params,query)会产生形状为(?,368,368,5)的张量。

查询张量通过执行:

来实现
query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices

总的来说,我尝试通过第三轴上的第一个元素(对于倒角距离的种类)对参数张量进行排序。最后要提到的是,我使用Keras框架。

1 个答案:

答案 0 :(得分:2)

您需要将第一维的索引添加到query,以便将其与tf.gather_nd一起使用。这是一种方法:

import tensorflow as tf
import numpy as np

np.random.seed(100)

with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.placeholder(tf.float32, [None, 368, 5])
    query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
    n = tf.shape(params)[0]
    # Make tensor of indices for the first dimension
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
    # Stack indices
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check the order is correct
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True