我有以下程序
import numpy as np
arr = np.random.randn(3,4)
print(arr)
regArr = (arr > 0.8)
print (regArr)
print (arr[ regArr].reshape(arr.shape))
输出:
[[ 0.37182134 1.4807685 0.11094223 0.34548185]
[ 0.14857641 -0.9159358 -0.37933393 -0.73946522]
[ 1.01842304 -0.06714827 -1.22557205 0.45600827]]
我正在寻找arr中的输出,其中应该存在大于0.8的值,而其他值则为零。
如上所述,我尝试使用bool掩蔽。但是我能够解决这个问题。请帮助
答案 0 :(得分:1)
我不确定您到底要实现什么,但这是我要过滤的。
arr = np.random.randn(3,4)
array([[-0.04790508, -0.71700005, 0.23204224, -0.36354634],
[ 0.48578236, 0.57983561, 0.79647091, -1.04972601],
[ 1.15067885, 0.98622772, -0.7004639 , -1.28243462]])
arr[arr < 0.8] = 0
array([[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ],
[1.15067885, 0.98622772, 0. , 0. ]])
感谢user3053452,我又添加了一种解决方案,该解决方案不会更改原始数据。
arr = np.random.randn(3,4)
array([[ 0.4297907 , 0.38100702, 0.30358291, -0.71137138],
[ 1.15180635, -1.21251676, 0.04333404, 1.81045931],
[ 0.17521058, -1.55604971, 1.1607159 , 0.23133528]])
new_arr = np.where(arr < 0.8, 0, arr)
array([[0. , 0. , 0. , 0. ],
[1.15180635, 0. , 0. , 1.81045931],
[0. , 0. , 1.1607159 , 0. ]])