我有一个2D数组,每一行代表一个分类器的输出,它将一些输入分类为3个类别(数组大小为1000 * 3
):
0.3 0.3 0.3
0.3 0.3 1.0
1.0 0.3 0.3
0.3 0.3 0.3
0.3 1.0 0.3
...
我想获得分类器所有输入的列表"不确定"关于他们。 我定义了"不确定"因为没有类别高于0.8。
为了解决这个问题,我使用:
np.where(model1_preds.max(axis=1) < 0.8)
这很有效。
但是现在我有6个分类器(以相同的顺序分析了相同的输入),以及表示结果的数组6 * 1000 * 3
。
我想找到两件事:
我认为总体方向是这样的:
np.stack(np.where(model_preds.max(axis=1) < 0.8) for model_preds in all_preds)
但它不会起作用,因为python不知道我在for循环中的含义。
答案 0 :(得分:3)
np.where
:
res_all_unsure = preds[:,np.amax(preds, axis=(0,2)) <= 0.8,:]
res_one_unsure = preds[:,preds.max(-1).min(0) <= 0.8,:]
答案 1 :(得分:2)
如果它已经是 6×1000×3矩阵preds
,您可以先np.transpose()
将其转换为1000×6×3矩阵。
y = preds.transpose(1,0,2) # preds is the input matrix, 6x1000x3
接下来我们可以将其转换为1000×6矩阵,对于每个实验和每个分类器,我们通过陈述来了解所有值是否小于0.8
:
y = np.all(y<0.8,axis=2)
最后,我们可以使用其他np.all()
来验证分类器不确定的位置所有:
all_classifiers_unsure = np.where(np.all(y,axis=1)) # all classifiers
或任何分类器不确定的地方:
any_classifier_unsure = np.where(np.any(y,axis=1)) # any of the classifiers
我们可以写得更短:
experiment_classifier = np.all(preds.transpose(1,0,2) < 0.8,axis=2)
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))
虽然我非常有信心,但请通过检查一些指数(那些是真的和那些不正确的指数)进行验证。
修改强>
您仍然可以使用.max() < 0.8
提议的方法,但使用axis=2
:
experiment_classifier = preds.transpose(1,0,2).max(axis=2) < 0.8
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))