假设我有以下numpy数组
import numpy as np
arr = np.array([[0.2, 0.8], [0.99, 0.01], [0.08, 0.92]])
arr
Out[57]:
array([[0.2 , 0.8 ],
[0.99, 0.01],
[0.08, 0.92]])
如果我想将此输出转换为“类”(或每一行中最大价值的索引),则只需使用:
arr.argmax(axis=1)
Out[58]: array([1, 0, 1], dtype=int64)
问题是,我想限制某个阈值。对于示例,我们使用 0.9 。因此,不符合阈值约束的每一行都将返回标签-1。
上述示例的输出将为[-1, 0, 1]
(因为0.8和0.2都不大于0.9)。
最Python的方式是什么?希望(但不是必须),使用numpy
。
答案 0 :(得分:2)
您可以使用np.where
:
UPDATE tuo
SET trasco_id_object = op.pratica
FROM trasco_utilizzi_oggetto tuo
JOIN oggetti_pratica op ON tuo.oggetto = op.oggetto
JOIN pratiche_tributo pt ON op.pratica = pt.pratica
JOIN denunce_ici di ON pt.pratica = di.pratica
WHERE op.pratica = (SELECT MIN(sq.pratica)
FROM oggetti_pratica sq
WHERE sq.pratica = pt.pratica
AND sq.oggetto = op.oggetto);
详细信息
m = arr > 0.9
np.where(m.any(axis=1), m.argmax(axis=1), -1)
array([-1, 0, 1])
返回具有相同形状的(arr > 0.9)
,指示满足条件的位置:
ndarray
array([[False, False],
[ True, False],
[False, True]])
返回m.argmax(axis=1)
是m
的地方:
True
对于满足array([0, 0, 1])
的那些行, np.where
将返回m.argmax(axis=1)
,因此至少一个元素大于阈值。这里m.any(axis=1)
给出:
m.any(axis=1)
否则,array([False, True, True])
将返回np.where