如何压扁numpy切片?

时间:2013-03-28 19:55:50

标签: python numpy

我正在实现numpy的ndarray的子​​类,我需要修改__getitem__以从数组的展平表示中获取项目。问题是__getitem__可以使用整数索引或多维切片调用。

是否有人知道如何将多维切片转换为展平数组上的索引列表(或单维切片)?

1 个答案:

答案 0 :(得分:3)

可能无法将多维切片转换为平切片,例如:

>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
>>> a[::3, 1::2]
array([[ 1,  3],
       [13, 15]])

并且您无法使用[ 1, 3, 13, 15]表示法访问子数组start:stop:step。但是你可以从多维索引构建一个平面索引列表,执行如下操作:

>>> row_idx = np.arange(4)[::3]
>>> col_idx = np.arange(4)[1::2]
>>> row_idx = np.repeat(row_idx, 2)
>>> col_idx = np.tile(col_idx, 2)
>>> np.ravel_multi_index((row_idx, col_idx), dims=(4,4))
array([ 1,  3, 13, 15], dtype=int64)

在更一般的设置中,一旦你有每个维度的索引数组,你需要修改所有索引数组的笛卡尔积,所以itertools.product可能是要走的路。例如:

>>> indices = [np.array([0, 4, 8]), np.array([1,7]), np.array([3, 5, 9])]
>>> indices = zip(*itertools.product(*indices))
>>> indices
[(0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 8),
 (1, 1, 1, 7, 7, 7, 1, 1, 1, 7, 7, 7, 1, 1, 1, 7, 7, 7),
 (3, 5, 9, 3, 5, 9, 3, 5, 9, 3, 5, 9, 3, 5, 9, 3, 5, 9)]
>>> np.ravel_multi_index(indices, dims=(10, 11, 12))
array([  15,   17,   21,   87,   89,   93,  543,  545,  549,  615,  617,
        621, 1071, 1073, 1077, 1143, 1145, 1149], dtype=int64)