将预测转换为带有阈值的标签

时间:2019-01-08 14:16:59

标签: python numpy

假设我有以下numpy数组

import numpy as np

arr = np.array([[0.2, 0.8], [0.99, 0.01], [0.08, 0.92]])
arr
Out[57]: 
array([[0.2 , 0.8 ],
       [0.99, 0.01],
       [0.08, 0.92]])

如果我想将此输出转换为“类”(或每一行中最大价值的索引),则只需使用:

arr.argmax(axis=1)
Out[58]: array([1, 0, 1], dtype=int64)

问题是,我想限制某个阈值。对于示例,我们使用 0.9 。因此,不符合阈值约束的每一行都将返回标签-1。

上述示例的输出将为[-1, 0, 1](因为0.8和0.2都不大于0.9)。

最Python的方式是什么?希望(但不是必须),使用numpy

1 个答案:

答案 0 :(得分:2)

您可以使用np.where

UPDATE tuo
SET trasco_id_object = op.pratica
FROM trasco_utilizzi_oggetto tuo
     JOIN oggetti_pratica op ON tuo.oggetto = op.oggetto
     JOIN pratiche_tributo pt ON op.pratica = pt.pratica
     JOIN denunce_ici di ON pt.pratica = di.pratica
WHERE op.pratica = (SELECT MIN(sq.pratica)
                    FROM oggetti_pratica sq
                    WHERE sq.pratica = pt.pratica
                      AND sq.oggetto = op.oggetto);

详细信息

m = arr > 0.9 np.where(m.any(axis=1), m.argmax(axis=1), -1) array([-1, 0, 1]) 返回具有相同形状的(arr > 0.9),指示满足条件的位置:

ndarray

array([[False, False], [ True, False], [False, True]]) 返回m.argmax(axis=1)m的地方:

True
对于满足array([0, 0, 1]) 的那些行,

np.where将返回m.argmax(axis=1),因此至少一个元素大于阈值。这里m.any(axis=1)给出:

m.any(axis=1)

否则,array([False, True, True]) 将返回np.where