使用重复优先级将火炬张量插入heapq时出错

时间:2019-06-03 22:50:12

标签: python torch heapq

如何避免使用此代码中的RuntimeError: bool value of Tensor with more than one value is ambiguous

import torch
import heapq

h = []
heapq.heappush(h, (1, torch.Tensor([[1,2]])))
heapq.heappush(h, (1, torch.Tensor([[3,4]])))

之所以会这样,是因为当第一个元素相等时,元组之间的比较会比较第二个元素

1 个答案:

答案 0 :(得分:0)

当找到重复的优先级并且只需要为我的元素重新定义<运算符时,有必要防止heapq尝试比较元组的第二个元素。

import torch
import heapq

class HeapItem:
    def __init__(self, p, t):
        self.p = p
        self.t = t

    def __lt__(self, other):
        return self.p < other.p

h = []
heapq.heappush(h, HeapItem(1, torch.Tensor([[1,2]])))
heapq.heappush(h, HeapItem(1, torch.Tensor([[3,4]])))