计算混淆矩阵的更快方法?

时间:2019-11-28 02:13:29

标签: python-3.x pytorch metrics confusion-matrix

我正在计算如下所示的混淆矩阵以进行图像语义分割,这是一种非常冗长的方法:

def confusion_matrix(preds, labels, conf_m, sample_size):
    preds = normalize(preds,0.9) # returns [0,1] tensor
    preds = preds.flatten()
    labels = labels.flatten()
    for i in range(len(preds)):
        if preds[i]==1 and labels[i]==1:
            conf_m[0,0] += 1/(len(preds)*sample_size) # TP
        elif preds[i]==1 and labels[i]==0:
            conf_m[0,1] += 1/(len(preds)*sample_size) # FP
        elif preds[i]==0 and labels[i]==0:
            conf_m[1,0] += 1/(len(preds)*sample_size) # TN
        elif preds[i]==0 and labels[i]==1:
            conf_m[1,1] += 1/(len(preds)*sample_size) # FN 
    return conf_m

在预测循环中:

conf_m = torch.zeros(2,2) # two classes (object or no-object)
for img,label in enumerate(data):
    ...
    out = Net(img)
    conf_m = confusion_matrix(out, label, len(data))
    ...

(在PyTorch中)是否有更快的方法来有效地计算用于图像语义分割的输入样本的混淆矩阵?

3 个答案:

答案 0 :(得分:2)

我使用这两个函数来计算混淆矩阵(如sklearn中所定义):

# rewrite sklearn method to torch
def confusion_matrix_1(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    return torch.sparse.LongTensor(
        torch.stack([y_true, y_pred]), 
        torch.ones_like(y_true, dtype=torch.long),
        torch.Size([N, N])).to_dense()

# weird trick with bincount
def confusion_matrix_2(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    y = N * y_true + y_pred
    y = torch.bincount(y)
    if len(y) < N * N:
        y = torch.cat(y, torch.zeros(N * N - len(y), dtype=torch.long))
    y = y.reshape(N, N)
    return T

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

confusion_matrix_1(y_true, y_pred)
# tensor([[2, 0, 0],
#         [0, 0, 1],
#         [1, 0, 2]])

在类数量较少的情况下,第二个功能更快。

%%timeit
confusion_matrix_1(y_true, y_pred)
# 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
confusion_matrix_2(y_true, y_pred)
# 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

答案 1 :(得分:0)

感谢Grigory Feldman的答案!我必须更改一些内容以配合实现。

对于将来的观察者,这是我的最终函数,该函数求和一批输入中每个混淆矩阵的百分比(用于训练或测试循环)

def confusion_matrix_2(y_true, y_pred, sample_sz, conf_m):
    y_pred = normalize(y_pred,0.9)
    obj = y_true[y_true==1]
    no_obj = y_true[y_true==0]
    N = torch.tensor(torch.max(torch.max(y_true), torch.max(y_pred)) + 1,dtype=torch.int)
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    y = N * y_true + y_pred
    y = torch.bincount(y.flatten())
    if len(y) < N * N:
        y = torch.cat((y, torch.zeros(N * N - len(y), dtype=torch.long)))
    y = y.reshape(N.item(), N.item())
    y = y.float()
    conf_m[0,:] += y[0,:]/(len(no_obj)*sample_sz)
    conf_m[1,:] += y[1,:]/(len(obj)*sample_sz)
    return conf_m

...
conf_m = torch.zeros((2, 2),dtype=torch.float) # two classes (object or no-object)
for _, data in enumerate(dataloader):
    for img,label in enumerate(data):
        ...
        out = Net(img)
        conf_m = confusion_matrix(out, label, len(data))
        ...
    ...

答案 2 :(得分:0)

也感谢 Grigory Feldman 的回答!
Mr.O 和我用 numpy 做的。

# weird trick with bincount
def confusion_matrix_2_numpy(y_true, y_pred, N=None):
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1) 
    if (N is None):
        N = max(max(y_true), max(y_pred)) + 1
    y = N * y_true + y_pred
    y = np.bincount(y, minlength=N*N)
    y = y.reshape(N, N)
    return y

请试试这个。
当您使用已知的 class_num 时,它会更快。
我应该提到的一件事是,可能存在最大值与原始类数不匹配的情况。
例如,当批量较小并且每次迭代都对混淆矩阵进行积分时。