包含np.arrays的元组的成员资格测试

时间:2015-07-07 03:34:03

标签: python numpy

我试图测试包含标量元组和np.arrays的列表/元组的成员资格。它适用于常规数组,但不适用于np数组。以下打印的第一个打印语句" True",第二个打印ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

huge = [(5.0, [[ 3., -1.],
       [-1.,  2.]], [ 7.,  5.]), (2.0, [[ 2.,  1.],
       [ 1.,  1.]]), [-2.,  5.], (2.0, [[ 1.,  0.],
       [ 0.,  2.]], [ 0.,  1.]), (1.0,[[ 0.2,  0.1],
       [ 0.1,  1. ]], [-3.,  4.])]
lil = (2.0,[[ 1.,  0.],
       [ 0.,  2.]],[ 0.,  1.])

nphuge = [(5.0, np.array([[ 3., -1.],
       [-1.,  2.]]), np.array([ 7.,  5.])), (2.0, np.array([[ 2.,  1.],
       [ 1.,  1.]]), np.array([-2.,  5.])), (2.0, np.array([[ 1.,  0.],
       [ 0.,  2.]]), np.array([ 0.,  1.])), (1.0, np.array([[ 0.2,  0.1],
       [ 0.1,  1. ]]), np.array([-3.,  4.]))]
nplil = (2.0, np.array([[ 1.,  0.],
       [ 0.,  2.]]), np.array([ 0.,  1.]))


print lil in huge #Prints "True"

print nplil in nphuge #Raises ValueError

我可以通过手动将每个元组的成员转换为常规列表而不是np.arrays来解决这个问题:

nplil_work_around = nplil[0],nplil[1].tolist(),nplil[2].tolist()
nphuge_work_around = [(x[0],x[1].tolist(), x[2].tolist()) for x in nphuge]

print nplil_work_around in nphuge_work_around # prints True

有没有办法在不转换np.arrays的情况下执行此操作?

1 个答案:

答案 0 :(得分:0)

您可以使用:

any( all((np.all(f==s) for f, s in zip(nplil, nphuge[i]))) for i in range(len(nphuge)) )

一步一步:

>>> nphuge = [(5.0, np.array([[ 3., -1.],
       [-1.,  2.]]), np.array([ 7.,  5.])), (2.0, np.array([[ 2.,  1.],
       [ 1.,  1.]]), np.array([-2.,  5.])), (2.0, np.array([[ 1.,  0.],
       [ 0.,  2.]]), np.array([ 0.,  1.])), (1.0, np.array([[ 0.2,  0.1],
       [ 0.1,  1. ]]), np.array([-3.,  4.]))]
>>> nplil = (2.0, np.array([[ 1.,  0.],
       [ 0.,  2.]]), np.array([ 0.,  1.]))

In [54]: [np.all(f==s) for f, s in zip(nplil, nphuge[0])]
Out[54]: [False, False, False]

In [55]: [np.all(f==s) for f, s in zip(nplil, nphuge[1])]
Out[55]: [True, False, False]

In [56]: [np.all(f==s) for f, s in zip(nplil, nphuge[2])]
Out[56]: [True, True, True]

>>> [ all((np.all(f==s) for f, s in zip(nplil, nphuge[i]))) for i in range(len(nphuge)) ]
[False, False, True, False]
>>> any( [ all((np.all(f==s) for f, s in zip(nplil, nphuge[i]))) for i in range(len(nphuge)) ] )
True