NumPy:排序3D数组,但保留第一个第二维

时间:2016-03-23 14:20:44

标签: sorting numpy

我有一个数组代表玩家持有的扑克牌。每位玩家持有6张牌,牌数为1-12,并且相应的牌照为1-4。

例如,第一位玩家将持有以下7张牌:

deck=np.array([[[  6.,   2.],
                [ 10.,   1.],
                [  5.,   1.],
                [  9.,   2.],
                [  4.,   1.],
                [  3.,   2.],
                [ 11.,   2.]]])

我现在的问题是,当我对卡片进行分类以查看哪一张卡片具有最高价值时(在这种情况下为11张相应的套装2)

sortedcards=-np.sort(-unsortedCards,axis=1)

它不仅对第一列中的值进行排序,还对第二列中的值进行排序(这是诉讼)。

如何才对第一列进行排序,并将第二列分配给第一列,以便我不会丢失哪些值适合的信息?

请记住,上面的例子只有一个玩家,但会有几个玩家。所以数组有一个额外的维度。

重要提示:解决方案必须仅通过纯NumPy矩阵运算。

2 个答案:

答案 0 :(得分:1)

首先,您需要一个可用于对卡片进行分类的值。 一个简单的方法是value*4 + suit

sortval = deck[:,:,0]*4+deck[:,:,1]
sortval *= -1 # if you want largest first

然后使用np.argsort找出哪个索引属于哪里,并使用它来对您的套牌进行排序。它在默认情况下沿最后一个轴排序,这就是我们想要的。

sortedIdx = np.argsort(sortval)

现在您可以使用它来对您的套牌进行分类:

deck = deck[np.arange(len(deck))[:,np.newaxis],sortedIdx]

np.arange...部分确保sortedIdx中的每个第二维索引数组与正确的第一维索引配对。

整件事:

import numpy as np

deck = np.array([[[  6.,   2.],
                  [ 10.,   1.],
                  [  5.,   1.],
                  [  9.,   2.],
                  [  4.,   1.],
                  [  3.,   2.],
                  [ 11.,   2.]],

                 [[  6.,   2.],
                  [  2.,   2.],
                  [  2.,   3.],
                  [ 11.,   1.],
                  [ 11.,   3.],
                  [  5.,   3.],
                  [  4.,   4.]]])

sortval = deck[:,:,0]*4+deck[:,:,1]
sortval *= -1 # if you want largest first
sortedIdx = np.argsort(sortval)
deck = deck[np.arange(len(deck))[:,np.newaxis],sortedIdx]
print(deck)

将打印:

[[[ 11.   2.]
  [ 10.   1.]
  [  9.   2.]
  [  6.   2.]
  [  5.   1.]
  [  4.   1.]
  [  3.   2.]]

 [[ 11.   3.]
  [ 11.   1.]
  [  6.   2.]
  [  5.   3.]
  [  4.   4.]
  [  2.   3.]
  [  2.   2.]]]

答案 1 :(得分:0)

您是否只对值进行排序以查看哪个值具有最高值? 因为在这种情况下为什么不使用np.max()?:

deck=np.array([[[  6.,   2.],
                [ 10.,   1.],
                [  5.,   1.],
                [  9.,   2.],
                [  4.,   1.],
                [  3.,   2.],
                [ 11.,   2.]],
            [[  7.,   2.],
                [ 8.,   1.],
                [  1.,   1.],
                [  9.,   2.],
                [  4.,   1.],
                [  3.,   2.],
                [ 12.,   2.]]])

np.max(deck)
Out[4]: 12.0

np.max(deck[0])
Out[5]: 11.0