在多维数组上使用numpy.where

时间:2017-07-21 21:23:42

标签: python arrays numpy

我有一个2D数组,每一行代表一个分类器的输出,它将一些输入分类为3个类别(数组大小为1000 * 3):

0.3 0.3 0.3
0.3 0.3 1.0
1.0 0.3 0.3
0.3 0.3 0.3
0.3 1.0 0.3
...

我想获得分类器所有输入的列表"不确定"关于他们。 我定义了"不确定"因为没有类别高于0.8。

为了解决这个问题,我使用:

np.where(model1_preds.max(axis=1) < 0.8)

这很有效。

但是现在我有6个分类器(以相同的顺序分析了相同的输入),以及表示结果的数组6 * 1000 * 3

我想找到两件事:

  1. 至少有一个分类器的所有输入都是&#34;不确定&#34;关于。
  2. 所有分类器的所有输入都是&#34;不确定&#34;约。
  3. 我认为总体方向是这样的:

    np.stack(np.where(model_preds.max(axis=1) < 0.8) for model_preds in all_preds)
    

    但它不会起作用,因为python不知道我在for循环中的含义。

2 个答案:

答案 0 :(得分:3)

np.where

的替代方案
res_all_unsure = preds[:,np.amax(preds, axis=(0,2)) <= 0.8,:]
res_one_unsure = preds[:,preds.max(-1).min(0) <= 0.8,:]

答案 1 :(得分:2)

如果它已经是 6×1000×3矩阵preds ,您可以先np.transpose()将其转换为1000×6×3矩阵。

y = preds.transpose(1,0,2)  # preds is the input matrix, 6x1000x3

接下来我们可以将其转换为1000×6矩阵,对于每个实验和每个分类器,我们通过陈述来了解所有值是否小于0.8

y = np.all(y<0.8,axis=2)

最后,我们可以使用其他np.all()来验证分类器不确定的位置所有

all_classifiers_unsure = np.where(np.all(y,axis=1))  # all classifiers

任何分类器不确定的地方:

any_classifier_unsure = np.where(np.any(y,axis=1))   # any of the classifiers

我们可以写得更短:

experiment_classifier = np.all(preds.transpose(1,0,2) < 0.8,axis=2)
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))

虽然我非常有信心,但请通过检查一些指数(那些是真的和那些不正确的指数)进行验证。

修改

您仍然可以使用.max() < 0.8提议的方法,但使用axis=2

experiment_classifier = preds.transpose(1,0,2).max(axis=2) < 0.8
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))