在NumPy中使用结构化对象数组

时间:2017-07-24 12:15:45

标签: python arrays numpy

说,我有以下结构的(x,y)点数组:

arr = np.array([([1.     ], [2.     ]),
                ([1., 93.], [5., 46.]),
                ([4.     ], [3.     ])],
               dtype=[('x','O'), ('y', 'O')])

即。这些点被分组为这样的最内层数组。最内层数组的大小可以是任意的,但对于x和y,它总是相同的。

我希望能够执行两件事:

a)通过连接内容扩展最里面的数组,因此对于上面的例子,结果如下:

np.array([( 1.,  2.),
          ( 1.,  5.),
          (93., 46.),
          ( 4.,  3.)],
         dtype=[('x','f8'), ('y','f8')])

b)对于每个(最外层)条目选择元素,例如,最大y:

np.array([( 1.,  2.),
          (93., 46.),
          ( 4.,  3.)],
         dtype=[('x','f8'), ('y','f8')])

我相信应该有一种方法可以有效地做到这一点而不使用丑陋的for循环。非常感谢任何帮助。

UPD(a和b使用丑陋的循环):

(arr是帖子开头定义的数组)

A)

np.array([(x_, y_) for x, y in arr for x_, y_ in zip(x, y)], dtype=[('x','f8'), ('y','f8')])

b)中

np.array([(x[np.argmax(np.array(y))], y[np.argmax(np.array(y))]) for x, y in arr],dtype=[('x','f8'), ('y','f8')])

问题还在于,实际上我不只有两个字段(x和y),而是各种类型的77个字段(浮点数,整数,布尔值)......所以这些表达式将会增长到很多行。

1 个答案:

答案 0 :(得分:1)

使用Pandas,您可以使用group值将数据存储在平面DataFrame中,以指示数据来自原始数组的哪一行:

import numpy as np
import pandas as pd
df = pd.DataFrame([
    (0, 1, 2),
    (1, 1, 5),
    (1, 93, 46),
    (2, 4, 3)], dtype='f8', columns=['group', 'x', 'y'])
print(df)
#    group     x     y
# 0    0.0   1.0   2.0
# 1    1.0   1.0   5.0
# 2    1.0  93.0  46.0
# 3    2.0   4.0   3.0

然后第一个操作只是xy列的一部分:

print(df[['x','y']])
#       x     y
# 0   1.0   2.0
# 1   1.0   5.0
# 2  93.0  46.0
# 3   4.0   3.0

可以使用groupby/idxmax完成第二个操作:

print(df.loc[df.groupby('group')['y'].idxmax(), ['x', 'y']])
#       x     y
# 0   1.0   2.0
# 2  93.0  46.0
# 3   4.0   3.0

鉴于结构化的NumPy数组arr,您将不得不循环 列表至少执行一次以执行任何操作。因此,您可以付出一次代价来将数据组织到更好的数据结构中,例如Pandas DataFrame。

以下是将arr转换为df的一种方式:

import numpy as np
import pandas as pd

arr = np.array([([1.     ], [2.     ]),
                ([1., 93.], [5., 46.]),
                ([4.     ], [3.     ])],
               dtype=[('x','O'), ('y', 'O')])

df = pd.DataFrame(arr)
df = (pd.concat({col: df[col].apply(pd.Series).stack() for col in df}, axis=1)
      .reset_index(drop=True))
print(df)

产量

      x     y
0   1.0   2.0
1   1.0   5.0
2  93.0  46.0
3   4.0   3.0