折叠嵌套的数组数组

时间:2012-10-17 06:32:27

标签: python numpy

我想采用形状(N,)dtype=object的数组,这些数组都具有相同的形状shape,并创建一个shape == (N,) + shape的数组。我想知道是否有人知道这样做的最好方法。这是一个例子。

import numpy as np
array = np.empty(4, dtype=object)
array[:] = [np.ones([3, 2])]
array = np.array(array.tolist())
print array.dtype
# float64
print array.shape
# (4, 3, 2)

1 个答案:

答案 0 :(得分:0)

如果您已经知道内部数组的形状(此处为(3,2)),则可以简化整个过程

subshape = (3,2)
a = np.empty(tuple([N,]+list(subshape)), dtype=object)
a[:] = np.ones(subshape)

这样可以避免对列表进行不必要的转换。


现在,假设您有一个(N,)对象数组a,其中每个元素都是subshape浮点数组,您可以这样做:

a = np.vstack(a)
a.shape = [N,] + list(subshape)

或更简单:

a = np.array(a.tolist(), dtype=float)

.tolist转换可能效率不高。