我有一个形状为n×2的numpy数组,一堆长度为2的元组,我想转移到SortedList。因此,我们的目标是使用长度为2的整数元组创建一个SortedList。
问题是SortedList的构造函数检查每个条目的真值。这适用于一维数组:
In [1]: import numpy as np
In [2]: from sortedcontainers import SortedList
In [3]: a = np.array([1,2,3,4])
In [4]: SortedList(a)
Out[4]: SortedList([1, 2, 3, 4], load=1000)
但是对于两个维度,当每个条目都是一个数组时,没有明确的真值,且SortedList不合作:
In [5]: a.resize(2,2)
In [6]: a
Out[6]:
array([[1, 2],
[3, 4]])
In [7]: SortedList(a)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-7-7a4b2693bb52> in <module>()
----> 1 SortedList(a)
/home/me/miniconda3/envs/env/lib/python3.6/site-packages/sortedcontainers/sortedlist.py in __init__(self, iterable, load)
81
82 if iterable is not None:
---> 83 self._update(iterable)
84
85 def __new__(cls, iterable=None, key=None, load=1000):
/home/me/miniconda3/envs/env/lib/python3.6/site-packages/sortedcontainers/sortedlist.py in update(self, iterable)
176 _lists = self._lists
177 _maxes = self._maxes
--> 178 values = sorted(iterable)
179
180 if _maxes:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我目前的解决方法是手动将每一行转换为元组:
sl = SortedList()
for t in np_array:
x, y = t
sl.add((x,y))
但是,此解决方案提供了一些改进空间。有没有人知道如何解决这个问题,而无需将所有数组显式解包为元组?
答案 0 :(得分:1)
问题不在于检查数组的真值,而是对它们进行比较,以便对它们进行排序。如果在数组上使用比较运算符,则会返回数组:
>>> import numpy as np
>>> np.array([1, 4]) < np.array([2, 3])
array([ True, False], dtype=bool)
这个结果的布尔数组实际上是sorted
检查其真值的数组。
另一方面,与元组(或列表)相同的操作将逐元素进行元素比较并返回单个布尔值:
>>> (1, 4) < (2, 3)
True
>>> (1, 4) < (1, 3)
False
因此当SortedList
尝试在sorted
数组序列上使用numpy
时,它无法进行比较,因为它需要通过比较返回单个布尔值运算符。
抽象这种方法的一种方法是创建一个新的数组类,它实现比较运算符,如__eq__
,__lt__
,__gt__
等,以重现元组的排序行为。具有讽刺意味的是,最简单的方法是将基础数组转换为元组,如:
class SortableArray(object):
def __init__(self, seq):
self._values = np.array(seq)
def __eq__(self, other):
return tuple(self._values) == tuple(other._values)
# or:
# return np.all(self._values == other._values)
def __lt__(self, other):
return tuple(self._values) < tuple(other._values)
def __gt__(self, other):
return tuple(self._values) > tuple(other._values)
def __le__(self, other):
return tuple(self._values) <= tuple(other._values)
def __ge__(self, other):
return tuple(self._values) >= tuple(other._values)
def __str__(self):
return str(self._values)
def __repr__(self):
return repr(self._values)
通过此实现,您现在可以对SortableArray
个对象列表进行排序:
In [4]: ar1 = SortableArray([1, 3])
In [5]: ar2 = SortableArray([1, 4])
In [6]: ar3 = SortableArray([1, 3])
In [7]: ar4 = SortableArray([4, 5])
In [8]: ar5 = SortableArray([0, 3])
In [9]: lst1 = [ar1, ar2, ar3, ar4, ar5]
In [10]: lst1
Out[10]: [array([1, 3]), array([1, 4]), array([1, 3]), array([4, 5]), array([0, 3])]
In [11]: sorted(lst1)
Out[11]: [array([0, 3]), array([1, 3]), array([1, 3]), array([1, 4]), array([4, 5])]
这可能对你所需要的东西有些过分,但这只是一种方法。在任何情况下,您都不会在一系列对象上使用sorted
,而这些对象在比较时不会返回单个布尔值。
如果您之后所有人都在避免使用for循环,那么您只需将其替换为列表解析(即SortedList([tuple(row) for row in np_array])
)。