将torch.topk的dim参数合并到tf.nn.top_k中

时间:2018-09-01 08:59:29

标签: python tensorflow pytorch

Pytorch提供了torch.topk(input, k, dim=None, largest=True, sorted=True)函数,用于计算沿给定维度k的给定input张量的dim个最大元素。

我的形状为(16, 512, 4096)的张量,并且我以以下方式使用torch.topk-

# inputs.shape (16L, 512L, 4096L)
dist, idx = torch.topk(inputs, 64, dim=2, largest=False, sorted=False)
# dist.shape (16L, 512L, 64L), idx.shape (16L, 512L, 64L)

我发现以下类似的张量流实现-tf.nn.top_k(input, k=1, sorted=True, name=None)

我的问题是如何在dim=2中引入tf.nn.top_k参数,以实现与pytorch计算的形状相同的张量?

1 个答案:

答案 0 :(得分:1)

tf.nn.top_k适用于输入的最后一个维度。这意味着它应该像您的示例一样工作:

dist, idx = tf.nn.top_k(inputs, 64, sorted=False)

通常,您可以想象Tensorflow版本的工作方式类似于带有硬编码dim=-1(即最后一个维度)的Pytorch版本。

但是看起来您实际上需要k个最小元素。在这种情况下,我们可以做到

dist, idx = tf.nn.top_k(-1*inputs, 64, sorted=False)
dist = -1*dist

因此,我们采用输入中k个最大的值,它们是原始输入中k个最小的值。然后,我们将值的负值反转。