我在张量流中有一个形状为(16, 512, 4096)
的张量,我想从张量中计算出最小的k
个元素。
请注意,我可以使用以下代码段在pytorch中获取它-
#inputs.shape (16L, 512L, 4096L)
dists, inputs_idx = torch.topk(inputs, 64, 2, largest=False, sorted=False)
#dists.shape (16L, 512L, 64L), inputs_idx.shape (16L, 512L, 64L)
请问有什么解决方法?