I have a trained 10-class machine learning model, similar to an MNIST digit classifier. I want to determine how often each class is predicted correctly and, when the model makes a mistake, which class it gets confused with. I want to use this to build a confusion matrix.
Anyway, on each pass the input is a batch of images from the validation set (shape (32, 3, 224, 224), where 3, 224 and 224 are the image dimensions and 32 is the batch size) plus labels (shape (32, 1)) holding the class numbers that match those images. The model output has shape (32, 1) and lists the class number the model considers the best match for each image. By comparing the label and output tensors I can easily count how many predictions match, but I'm having trouble telling which class each misclassification went to. Here is a snippet of the main validation loop:
# Main validation loop
valid_accuracy = 0.0
model.eval()
device = 'cuda'
raw_counts = torch.zeros((11, 11))  # leave room for totals in the last row and column
with torch.no_grad():
    for inputs, labels in validloader:
        # Run each image through the network to get log probabilities of each class
        inputs, labels = inputs.to(device), labels.to(device)
        logps = model.forward(inputs)
        # Calculate accuracy
        ps = torch.exp(logps)  # 32 x 10: probability of each label in every case
        top_p, top_class = ps.topk(1, dim=1)
        equals = top_class == labels.view(*top_class.shape)
        valid_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
        # Count confusions
        raw_counts[labels[:], top_class[:, 0]] += 1  # <<<<---- This is the problem!
        # Accumulate letter-by-letter certainties
        for i in range(ll):
            sum_ps[i,] += sum(ps[labels == i])
            letter_counts[i] += len(ps[labels == i])
# Print validation accuracy
valid_accuracy = valid_accuracy / len(validloader)
print(f"Validation accuracy: {valid_accuracy:.3f}")
The problem is where I try to count the confusions. With only ten classes and a batch of 32, I'm guaranteed to have repeated labels, yet
raw_counts[labels[:],top_class[:,0]] += 1
only ever adds one to each affected entry of the raw_counts matrix. For example, in the debugger just before this line:
(Pdb) top_class[:,0]
tensor([5, 9, 5, 0, 2, 3, 3, 8, 2, 9, 6, 3, 0, 3, 1, 3, 3, 4, 0, 1, 5, 2, 8, 4,
5, 3, 6, 5, 0, 3, 2, 1], device='cuda:0')
(Pdb) labels[:]
tensor([5, 9, 5, 0, 2, 3, 3, 8, 2, 9, 6, 3, 0, 3, 1, 3, 3, 4, 0, 1, 5, 2, 8, 4,
5, 3, 8, 5, 0, 3, 2, 1], device='cuda:0')
(Pdb) raw_counts
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
everything is as expected. But after executing the line:
(Pdb) n
> /home/model.py(219)main()
-> for i in range(10):
(Pdb) raw_counts
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
Even though five of the labels in this batch are 5 and the model got all of them right, raw_counts[5,5] == 1. Is there a Pythonic way to count all five correct answers without resorting to a mess of loops?
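A minimal standalone sketch (hypothetical tensors, not taken from the loop above) reproduces the behaviour: a fancy-indexed += writes each duplicate (row, column) pair only once, whereas Tensor.index_put_ with accumulate=True sums over duplicates:
import torch

labels    = torch.tensor([5, 5, 5, 5, 5, 4, 4])   # five correct 5s, two correct 4s
top_class = torch.tensor([5, 5, 5, 5, 5, 4, 4])

counts = torch.zeros(10, 10)
counts[labels, top_class] += 1
print(counts[5, 5])   # tensor(1.) -- duplicate index pairs collapse into a single write

counts = torch.zeros(10, 10)
counts.index_put_((labels, top_class), torch.ones(len(labels)), accumulate=True)
print(counts[5, 5])   # tensor(5.) -- duplicates are accumulated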
Answer 0 (score: -1)
I don't have enough reputation to comment, so I'll copy my answer here.
I use these two functions to compute the confusion matrix (as defined in sklearn):
import torch

# rewrite sklearn method to torch
def confusion_matrix_1(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    return torch.sparse.LongTensor(
        torch.stack([y_true, y_pred]),
        torch.ones_like(y_true, dtype=torch.long),
        torch.Size([N, N])).to_dense()

# weird trick with bincount
def confusion_matrix_2(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    y = N * y_true + y_pred
    y = torch.bincount(y)
    if len(y) < N * N:
        y = torch.cat((y, torch.zeros(N * N - len(y), dtype=torch.long)))
    y = y.reshape(N, N)
    return y
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

confusion_matrix_1(y_true, y_pred)
# tensor([[2, 0, 0],
#         [0, 0, 1],
#         [1, 0, 2]])
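confusion_matrix_2 on the same inputs should give an identical matrix:
confusion_matrix_2(y_true, y_pred)
# tensor([[2, 0, 0],
#         [0, 0, 1],
#         [1, 0, 2]])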
With a small number of classes, the second function is faster.
%%timeit
confusion_matrix_1(y_true, y_pred)
# 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
confusion_matrix_2(y_true, y_pred)
# 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
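For completeness, here is a rough, untested sketch of how the bincount idea could be plugged into the validation loop from the question (assuming model, validloader and device exactly as defined there); the class count is fixed at 10 so every batch contributes a full 10x10 matrix:
import torch

num_classes = 10
confusion = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)

model.eval()
with torch.no_grad():
    for inputs, labels in validloader:
        inputs, labels = inputs.to(device), labels.to(device)
        ps = torch.exp(model(inputs))          # (batch, 10) class probabilities
        top_class = ps.argmax(dim=1)           # (batch,) predicted class per image
        # flatten each (true, predicted) pair into a single index, then bincount
        idx = num_classes * labels.view(-1) + top_class
        counts = torch.bincount(idx, minlength=num_classes ** 2)
        confusion += counts.reshape(num_classes, num_classes)

print(confusion.cpu())   # rows = true class, columns = predicted class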