使用自定义dtype按行对对象数组进行排序

时间:2020-10-05 19:50:32

标签: python numpy sorting

我试图按行对字典进行排序。整数大小写可以完美地工作:

>>> arr = np.random.choice(10, size=(5, 3))
>>> arr
array([[1, 0, 2],
       [8, 0, 8],
       [1, 8, 4],
       [1, 3, 9],
       [6, 1, 8]])
>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()
>>> arr
array([[1, 0, 2],
       [1, 3, 9],
       [1, 8, 4],
       [6, 1, 8],
       [8, 0, 8]])

我也可以使用

进行排序
np.ndarray(arr.shape[0], dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr).sort()

在两种情况下,结果都是相同的。但是,对象数组并非如此:

>>> selection = np.array(list(string.ascii_lowercase), dtype=object)
>>> arr = np.random.choice(selection, size=(5, 3))
>>> arr
array([['t', 'p', 'g'],
       ['n', 's', 'd'],
       ['g', 'g', 'n'],
       ['g', 'h', 'o'],
       ['f', 'j', 'x']], dtype=object)
>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()
>>> arr
array([['t', 'p', 'g'],
       ['n', 's', 'd'],
       ['g', 'h', 'o'],
       ['g', 'g', 'n'],
       ['f', 'j', 'x']], dtype=object)
>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr).sort()
>>> arr
array([['f', 'j', 'x'],
       ['g', 'g', 'n'],
       ['g', 'h', 'o'],
       ['n', 's', 'd'],
       ['t', 'p', 'g']], dtype=object)

很明显,只有dtype=[('', arr.dtype)] * arr.shape[1]的情况才能正常工作。这是为什么? dtype=[('', arr.dtype, arr.shape[1])]有什么不同?排序显然在做些事情,但是乍一看顺序似乎是荒谬的。它使用指针作为排序键吗?

就其价值而言,np.searchsorted似乎在进行与np.sort相同的比较,符合预期。

4 个答案:

答案 0 :(得分:0)

这实际上很好

In [16]: selection = np.array(list(string.ascii_lowercase))

In [17]: arr = np.random.choice(selection, size=(5, 3))

In [18]: arr
Out[18]:
array([['x', 'l', 'i'],
       ['k', 'h', 'b'],
       ['y', 'h', 'w'],
       ['i', 'u', 't'],
       ['v', 'u', 'k']], dtype='<U1')

In [19]: np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()

In [20]: arr
Out[20]:
array([['i', 'u', 't'],
       ['k', 'h', 'b'],
       ['v', 'u', 'k'],
       ['x', 'l', 'i'],
       ['y', 'h', 'w']], dtype='<U1')

问题在于使用dtype object进行选择。

In [21]: selection = np.array(list(string.ascii_lowercase), dtype = object)

In [22]: arr = np.random.choice(selection, size=(5, 3))

In [23]: arr
Out[23]:
array([['b', 'h', 'e'],
       ['o', 'z', 'c'],
       ['g', 'v', 'z'],
       ['r', 'n', 'k'],
       ['a', 'h', 't']], dtype=object)

In [24]: np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()

In [25]: arr
Out[25]:
array([['o', 'z', 'c'],
       ['b', 'h', 'e'],
       ['r', 'n', 'k'],
       ['a', 'h', 't'],
       ['g', 'v', 'z']], dtype=object)

注意dtype = 'O'表示numpy类型的python object see here for more,我认为它不提供比较运算符。

通常您提供的两种类型仍然可以使用。

答案 1 :(得分:0)

排序适用于整数的事实恰好是一个巧合,这可以通过查看浮点运算的结果来验证:

>>> arr = np.array([[0.5, 1.0, 10.2],
                    [0.4, 2.0, 11.0],
                    [1.0, 2.0, 4.0]])
>>> np.sort(np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr))
array([([ 0.5,  1. , 10.2],),
       ([ 1. ,  2. ,  4. ],),
       ([ 0.4,  2. , 11. ],)], dtype=[('f0', '<f8', (3,))])
>>> np.sort(np.ndarray(arr.shape[0], dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr))
array([(0.4, 2., 11. ),
       (0.5, 1., 10.2),
       (1. , 2.,  4. )],
      dtype=[('f0', '<f8'), ('f1', '<f8'), ('f2', '<f8')])

另一个提示来自查看数字 0.50.41.0 的位:

0.5 = 0x3FE0000000000000
0.4 = 0x3FD999999999999A
1.0 = 0x3FF6666666666666

在小端机器上,我们有 0x00 < 0x66 < 0x9A(上面显示的最后一个字节在前)。

可以通过查看 the source code 中的排序函数来验证确切的答案。例如,在 quicksort.c.src 中,我们看到所有不是显式数字的类型(包括不是标量的结构字段)都由 npy_quicksort 泛型函数处理。它使用函数 cmp 作为比较器,使用宏 GENERIC_SWAPGENERIC_COPY 分别进行交换和复制。

函数 cmp 定义为 PyArray_DESCR(arr)->f->compare。宏定义为 npysort_common.h 中的逐元素操作。

所以最后的结果是,对于任何非标量类型,包括压缩数组结构域,都是逐字节进行比较的。对于对象,这当然是指针的数值。对于浮点数,这将是 IEEE-754 表示。正整数似乎正常工作的事实是由于我的平台使用小端编码这一事实造成的。以二进制补码形式存储的负整数可能不会产生正确的结果。

答案 2 :(得分:-1)

这可能不是一个完美的答案,但我希望我能为您提供帮助:

1。)为什么不能正常工作:因为dtype=[('', arr.dtype)] * arr.shape[1]!= dtype=[('', arr.dtype, arr.shape[1])]

2。)两者之间有什么区别?好吧,虽然第一个将长度添加到列表中,但是第二个将列表增加了。 这意味着第一个输出为:[('', dtype('O'), 3)],而第二个输出为[('', dtype('O')), ('', dtype('O')), ('', dtype('O'))]

3。)排序显然做错了-没有输入的格式错误

4。)是否使用指针作为排序键?您是说它是否通过数据键格式化数据?然后不,它会根据数据本身对它们进行排序。

编辑: 可以更清楚地说:

首先,我认为您误解了@Mike MacNeil's anwer。为了使它更具可塑性,请参见以下示例:

让我们考虑一个Foo类:

class Foo:
    def __init__(self, id):
        self._id = id
    
    def get_id(self):
        return self._id

    def __le__(self, ob):
        return self < ob or self == ob

    def __lt__(self, ob):
        return self.get_id() < ob.get_id()

    def __ge__(self, ob):
        return not self < ob

    def __gt__(self, ob):
        return not self <= ob

    def __eq__(self, ob):
        return self.get_id() == ob.get_id()

    def __str__(self):
        return f'Foo({self.get_id()})'

    def __repr__(self):
        rep = super().__repr__()
        return f'{str(self)} {rep[rep.index("at"):rep.index(">")]}'

我们看到比较已经像string中那样实现了。我还实现了__repr__()__str__()方法,这给了我一点时间,您将了解原因:

让我们在第一步中创建一个numpy数组:

>>> arr4 = np.array([[Foo(1), Foo(2), Foo(3)],
        [Foo(4), Foo(5), Foo(6)],
        [Foo(7), Foo(8), Foo(9)],
        [Foo(10), Foo(11), Foo(12)]])

如果我们打印它,它将看起来像这样:

>>> arr4
array([[Foo(1) at 0x000002411F753F08, Foo(2) at 0x000002411F73FF48, Foo(3) at 0x000002411F74EE48],
       [Foo(4) at 0x000002411F74EE88, Foo(5) at 0x000002411F74EE08, Foo(6) at 0x000002411F756148],
       [Foo(7) at 0x000002411F7561C8, Foo(8) at 0x000002411F756208, Foo(9) at 0x000002411F756248],
       [Foo(10) at 0x000002411F756288, Foo(11) at 0x000002411F7562C8,
        Foo(12) at 0x000002411F756308]], dtype=object)

如果我们现在打印ndarray ...

>>> np.ndarray(arr4.shape[0], dtype=[('', arr4.dtype, arr4.shape[1])], buffer=arr4)
array([([Foo(1) at 0x000002411F753F08, Foo(2) at 0x000002411F73FF48, Foo(3) at 0x000002411F74EE48],),
       ([Foo(4) at 0x000002411F74EE88, Foo(5) at 0x000002411F74EE08, Foo(6) at 0x000002411F756148],),
       ([Foo(7) at 0x000002411F7561C8, Foo(8) at 0x000002411F756208, Foo(9) at 0x000002411F756248],),
       ([Foo(10) at 0x000002411F756288, Foo(11) at 0x000002411F7562C8, Foo(12) at 0x000002411F756308],)], dtype=[('f0', 'O', (3,))])

...我们看到它的形状基本上与

>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr)
array([['t', 'p', 'g'],
       ['n', 's', 'd'],
       ['g', 'h', 'o'],
       ['g', 'g', 'n'],
       ['f', 'j', 'x']], dtype=[('f0', 'O', (3,))])

np.ndarray(arr4.shape[0], dtype=[('', arr4.dtype, arr4.shape[1])], buffer=arr4).sort()对Foo-Array排序后,我们看到arr4的输出类似于:

>>> arr4
array([[Foo(1) at 0x000002411F753F08, Foo(2) at 0x000002411F73FF48, Foo(3) at 0x000002411F74EE48],
       [Foo(10) at 0x000002411F756288, Foo(11) at 0x000002411F7562C8, Foo(12) at 0x000002411F756308],
       [Foo(4) at 0x000002411F74EE88, Foo(5) at 0x000002411F74EE08, Foo(6) at 0x000002411F756148],
       [Foo(7) at 0x000002411F7561C8, Foo(8) at 0x000002411F756208, Foo(9) at 0x000002411F756248]], dtype=object)

尽管

>>> Foo(10) > Foo(4)
True

(仍然np.ndarray(arr4.shape[0], dtype=[('', arr4.dtype)] * arr4.shape[1], buffer=arr4).sort()可以使用定义的比较函数打印出按id键排序的预期结果。)

dtype=object的比较规则不像您期望的那样仅使用标准比较功能,而是在比较对象表示形式(→在这种情况下,这意味着repr(Foo(10)) < repr(Foo(2))将是{ {1}},尽管我们实际上希望Foo(10)大于Foo(2))。

但是通过告诉numpy确切的尺寸/形状,numpy使用了标准比较,这将产生预期的结果,因为它现在知道一行中的所有元素都来自完全相同的类型,而不仅仅是一些固定在一起的随机对象分成一个阵列这就是为什么您的示例也不能与True一起使用,但是可以与string一起使用的原因,因为numpy(str本身支持str

答案 3 :(得分:-2)

您的第一个方法dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()似乎正在尝试对dtype = object进行排序,但是它没有足够的信息来对其进行排序。当您使用第二种方法dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr).sort()时,它将解压缩对象,从而使sort方法可以“查看”其应该进行的排序。在标量上使用这些方法时,sort方法可以看到它们是标量,而不是对象。

这全都是我的猜想,但这对我来说很有意义。如果有人可以纠正我,请这样做!