numpy获取列索引,其中所有元素都大于阈值

时间:2017-07-25 15:01:54

标签: python numpy

我想找到一个numpy数组的列索引,其中列的所有元素都大于阈值。

例如,

 X = array([[ 0.16,  0.40,  0.61,  0.48,  0.20],
            [ 0.42,  0.79,  0.64,  0.54,  0.52],
            [ 0.64,  0.64,  0.24,  0.63,  0.43],
            [ 0.33,  0.54,  0.61,  0.43,  0.29],
            [ 0.25,  0.56,  0.42,  0.69,  0.62]])

在上述情况下,如果阈值为0.4,我的结果应为1,3。

3 个答案:

答案 0 :(得分:6)

您可以使用min

与每列的np.where进行比较
large = np.where(X.min(0) >= 0.4)[0]

答案 1 :(得分:0)

使用列表理解的通用解决方案

threshold = 0.4
rows_nb, col_nb = shape(X)
rows_above_threshold = [col for col in range(col_nb) \
    if all([X[row][col] >= threshold for row in range(rows_nb)])]

答案 2 :(得分:0)

x = array([[ 0.16,  0.40,  0.61,  0.48,  0.20],
        [ 0.42,  0.79,  0.64,  0.54,  0.52],
        [ 0.64,  0.64,  0.24,  0.63,  0.43],
        [ 0.33,  0.54,  0.61,  0.43,  0.29],
        [ 0.25,  0.56,  0.42,  0.69,  0.62]])

threshold = 0.3
size = numpy.shape(x)[0]
for it in range(size):
    y = x[it] > threshold
    print(y.all())

试试吧。