通过火炬中的张量迭代

时间:2020-04-29 14:24:27

标签: python pytorch tensor

我有两个一维张量。一个是预测的向量,第二是标签的向量。我正在尝试编写一个循环来检查向量之间的逐元素差异。如果发现了这种差异,我想做另一种操作,为简单起见,我要打印(“发现差异”)。到目前为止,我想到了这一点,但遇到了一个错误:标量类型为Byte的预期对象,但参数#2'other'的标量类型为Float。我希望在这里为您提供帮助。也许有一些更有效的方法可以做到无循环。

def partition(lst, l, h):
    lst.append(float("inf"))
    pivot = lst[0]
    i, j = l+1, h
    while i < j:
        while lst[i] < pivot:
            i += 1
        while lst[j] > pivot:      
            j -= 1
        if i < j:
            lst[i] , lst[j] = lst[j], lst[i]
        else:
            lst = lst[1:i] + [pivot] + lst[i:]
    return lst[:-1], i

def quickSort(lst, l, h):
    if l < h-1:
        mid = (l + h)//2
        lst[l:h], mid = partition(lst[l:h], 0, h-l)
        quickSort(lst, l, mid)
        quickSort(lst, mid, h)
        lst1 = [10, 12, 8, 16, 2, 6, 3, 9, 5]
        quickSort(lst1, 0, 9)

1 个答案:

答案 0 :(得分:0)

您可以在pytorch中使用eq()函数来检查张量是否与元素相同。对于与标签元素相同的元素的每个索引,您将获得一个True

for label in predictions.round().eq(labels):
    for element in label:
        if element == False:
            print("Diff spotted!")