数组中的条件选择

时间:2018-08-04 08:57:32

标签: python arrays numpy

我有以下带有数组的列表:

[array([10,  1,  7,  3]),
 array([ 0, 14, 12, 13]),
 array([ 3, 10,  7,  8]),
 array([7, 5]),
 array([ 5, 12,  3]),
 array([14,  8, 10])]

我想要将行标记为“ 1”或“ 0”,条件是该行是否匹配“ 10”和“ 7”或“ 10”和“ 3”。

np.where(output== 10 & output == 7 ) | (output == 10 & output == 3 ) | (output == 10 & output == 8 ), 1, 0)

返回

array(0)

进入数组数组的正确语法是什么?

预期输出:

[ 1, 0, 1, 0, 0, 1 ]

注意: 什么是output?在Scikit中训练CountVectorizer / LDA主题分类器后,以下脚本将主题概率分配给新文档。然后,将阈值高于0.2的主题存储在数组中。

def sortthreshold(x, thresh):
    idx = np.arange(x.size)[x > thresh]
    return idx[np.argsort(x[idx])]

output = []
for x in newdoc:
    y = lda.transform(bowvectorizer.transform([x]))
    output.append(sortthreshold(y[0], 0.2))

谢谢!

2 个答案:

答案 0 :(得分:2)

您的输入数据是不等长的Numpy数组的普通Python列表,因此不能简单地将其转换为2D Numpy数组,因此Numpy不能直接对其进行处理。但是可以使用常规的Python列表处理工具进行处理。

这是一个列表理解,它使用numpy.isin测试行是否包含(3,7,8)中的任何一个。我们首先使用简单的==测试来查看该行是否包含10,并且只有在这样做的情况下才调用isin。如果第一个操作数是假的,则Python and运算符将不会计算其第二个操作数。

我们使用np.any来查看是否有任何行项目通过了每个测试。 np.any返回布尔值FalseTrue,但是我们可以将这些值传递给int以将其转换为0或1。

import numpy as np

data = [
    np.array([10, 1, 7, 3]), np.array([0, 14, 12, 13]),
    np.array([3, 10, 7, 8]), np.array([7, 5]),
    np.array([5, 12, 3]), np.array([14, 8, 10]),
]

mask = np.array([3, 7, 8])
result = [int(np.any(row==10) and np.any(np.isin(row, mask)))
    for row in data]

print(result)

输出

[1, 0, 1, 0, 0, 1] 

我刚刚进行了一些timeit测试。奇怪的是,Reblochon Masque的代码在问题中给出的数据上速度更快,大概是由于普通Python anyandor的短路行为。而且,即使文档建议在新代码中使用numpy.in1d,看来timeit test也比numpy.isin快。

这是一个新版本,比Reblochon的版本慢10%。

mask = np.array([3, 7, 8])
result = [int(any(row==10) and any(np.in1d(row, mask)))
    for row in data]

当然,对大量真实数据的真实速度可能与我的测试所表明的有所不同。时间也许不是问题:即使是在速度较慢的旧32位单核2GHz机器上,我也可以在一秒钟内处理问题中的数据近3000次。


hpaulj提出了一种更快的方法。以下是一些3d-party lib信息,用于比较各种版本。这些测试是在我的旧计算机YMMV上进行的。

import numpy as np
from timeit import Timer

the_data = [
    np.array([10, 1, 7, 3]), np.array([0, 14, 12, 13]),
    np.array([3, 10, 7, 8]), np.array([7, 5]),
    np.array([5, 12, 3]), np.array([14, 8, 10]),
]

def rebloch0(data):
    result = []
    for output in data:
        result.append(1 if np.where((any(output == 10) and any(output == 7)) or
            (any(output == 10) and any(output == 3)) or
            (any(output == 10) and any(output == 8)), 1, 0) == True else 0)
    return result

def rebloch1(data):
    result = []
    for output in data:
        result.append(1 if np.where((any(output == 10) and any(output == 7)) or
            (any(output == 10) and any(output == 3)) or
            (any(output == 10) and any(output == 8)), 1, 0) else 0)
    return result

def pm2r0(data):
    mask = np.array([3, 7, 8])
    return [int(np.any(row==10) and np.any(np.isin(row, mask)))
        for row in data]

def pm2r1(data):
    mask = np.array([3, 7, 8])
    return [int(any(row==10) and any(np.in1d(row, mask)))
        for row in data]

def hpaulj0(data):
    mask=np.array([3, 7, 8])
    return [int(any(row==10) and any((row[:, None]==mask).flat))
        for row in data]

def hpaulj1(data, mask=np.array([3, 7, 8])):
    return [int(any(row==10) and any((row[:, None]==mask).flat))
        for row in data]

functions = (
    rebloch0,
    rebloch1,
    pm2r0,
    pm2r1,
    hpaulj0,
    hpaulj1,
)

# Verify that all functions give the same result
for func in functions:
    print('{:8}: {}'.format(func.__name__, func(the_data)))
print()

def time_test(loops, data):
    timings = []
    for func in functions:
        t = Timer(lambda: func(data))
        result = sorted(t.repeat(3, loops))
        timings.append((result, func.__name__))
    timings.sort()
    for result, name in timings:
        print('{:8}: {:.6f}, {:.6f}, {:.6f}'.format(name, *result))
    print()

time_test(1000, the_data)

典型输出

rebloch0: [1, 0, 1, 0, 0, 1]
rebloch1: [1, 0, 1, 0, 0, 1]
pm2r0   : [1, 0, 1, 0, 0, 1]
pm2r1   : [1, 0, 1, 0, 0, 1]
hpaulj0 : [1, 0, 1, 0, 0, 1]
hpaulj1 : [1, 0, 1, 0, 0, 1]

hpaulj1 : 0.140421, 0.154910, 0.156105
hpaulj0 : 0.154224, 0.154822, 0.167101
rebloch1: 0.281700, 0.282764, 0.284599
rebloch0: 0.339693, 0.359127, 0.375715
pm2r1   : 0.367677, 0.368826, 0.371599
pm2r0   : 0.626043, 0.628232, 0.670199

好工作,hpaulj!

答案 1 :(得分:1)

您需要结合使用np.anynp.where,并避免使用python中的二进制运算符|&

import numpy as np

a = [np.array([10,  1,  7,  3]),
     np.array([ 0, 14, 12, 13]),
     np.array([ 3, 10,  7,  8]),
     np.array([7, 5]),
     np.array([ 5, 12,  3]),
     np.array([14,  8, 10])]

for output in a:
    print(np.where(((any(output == 10) and any(output == 7))) or 
                   (any(output == 10) and any(output == 3)) or
                   (any(output == 10) and any(output == 8 )), 1, 0))

输出:

1
0
1
0
0
1

如果您希望将其作为已编辑问题显示的列表:

result = []
for output in a:
    result.append(1 if np.where(((any(output == 10) and any(output == 7))) or 
                   (any(output == 10) and any(output == 3)) or
                   (any(output == 10) and any(output == 8 )), 1, 0) == True else 0)

result

结果:

[1, 0, 1, 0, 0, 1]