我想在列表li
中找到最小张量的索引(通过某个键函数)。
所以我做了min
,之后做了li.index(min_el)
。我的MWE表明,某些张量不适用于index
。
import torch
li=[torch.ones(1,1), torch.zeros(2,2)]
li.index(li[0])
0
li.index(li[1])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../local/lib/python2.7/site-packages/torch/tensor.py", line 330, in __eq__
return self.eq(other)
RuntimeError: inconsistent tensor size at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:2679
我当然可以创建自己的索引函数,首先检查大小,然后检查元素。 E.g。
def index(list, element):
for i,el in enumerate(list):
if el.size() == element.size():
diff = el - element
if (1- diff.byte()).all():
return i
return -1
我只是想知道,为什么index
不起作用?
是不是有一种聪明的方法可以不用手工做到这一点我不知道?
答案 0 :(得分:1)
您可以使用enumerate
和对每个元组的第二个元素进行操作的键函数直接找到索引。例如,如果您的键函数比较每个张量的第一个元素,则可以使用
ix, _ = min(enumerate(li), key=lambda x: x[1][0, 0])
我认为第一个元素index
成功的原因是Python可能做的事情等同于x is value or x == value
,其中x
是列表中的当前元素。从value is value
开始,失败的相等比较永远不会发生。这就是为什么li.index(li[0])
有效但失败的原因:
y = torch.ones(1,1)
li.index(y)