如何在列表中找到张量的索引?

时间:2018-01-30 17:58:28

标签: python pytorch

我想在列表li中找到最小张量的索引(通过某个键函数)。 所以我做了min,之后做了li.index(min_el)。我的MWE表明,某些张量不适用于index

import torch
li=[torch.ones(1,1), torch.zeros(2,2)]
li.index(li[0])
0
li.index(li[1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../local/lib/python2.7/site-packages/torch/tensor.py", line 330, in __eq__
    return self.eq(other)
RuntimeError: inconsistent tensor size at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:2679

我当然可以创建自己的索引函数,首先检查大小,然后检查元素。 E.g。

def index(list, element):
    for i,el in enumerate(list):
        if el.size() == element.size():
            diff = el - element
            if (1- diff.byte()).all():
               return i
    return -1

我只是想知道,为什么index不起作用? 是不是有一种聪明的方法可以不用手工做到这一点我不知道?

1 个答案:

答案 0 :(得分:1)

您可以使用enumerate和对每个元组的第二个元素进行操作的键函数直接找到索引。例如,如果您的键函数比较每个张量的第一个元素,则可以使用

ix, _ = min(enumerate(li), key=lambda x: x[1][0, 0])

我认为第一个元素index成功的原因是Python可能做的事情等同于x is value or x == value,其中x是列表中的当前元素。从value is value开始,失败的相等比较永远不会发生。这就是为什么li.index(li[0])有效但失败的原因:

y = torch.ones(1,1)
li.index(y)