Numpy:3D布尔索引数组

时间:2016-03-13 13:17:59

标签: python arrays numpy boolean

通常,numpy数组比列表操作或循环快得多,但在这种情况下也是如此?:

我有一个4D数组和一个布尔索引数组用于前三个轴&#39 ;;索引的输出至少在索引轴上变平,因此它是一个元组列表' (但以阵列形式)。

由于常规结构被破坏,我认为这比常规网格的索引要慢得多(即,独立地索引每个轴)?也许numpy内部真的计算元组列表然后将其转换为数组?

为什么我要问:我想枚举输出,以便能够计算任何元组(如果它在列表中)以及在哪个位置。我试着了解哪种方法可以快速而优雅......

我的上下文: 我有一个整数koordinates数组,一个网格 - 所以逻辑上我有一个3元组的3D数组,但对于程序它是一个4D阵列。

我想得到坐标总和等于一个常数的所有点,这是从我的立方体中切出一个平面(最后,我拿了两个相邻的平面,这给了我一个蜂窝格子 - 它' s如果你喜欢数学就很漂亮:))

所以最后一个轴的值​​只是前三个轴的指数'。如果我不仅有一个TrueFalse的索引数组,而且还分配了一个id而不是每个True,那么我可以很容易地读出每个元组的id。

这可能是一种优雅而快速的任务方式(目标是知道其中一个平面中的每个站点哪个站点相邻 - 所以他们的坐标是已知的,但我想要他们的id)。

那么,numpy内部是否有任何魔法来获取索引数组?或者同样快速地采取for-loop;)(不,我通过尝试看到,这更快,但为什么......)

一些代码(德语评论,抱歉)

import numpy as np
Seitenlaenge = 4
kArray = np.zeros((Seitenlaenge, Seitenlaenge, Seitenlaenge, 3)) # 4D-Array, hier soll dann an der Stelle [x, y, z, :] der Vektor (x, y, z) stehen
kArray[:, :, :, 2] = np.arange(Seitenlaenge).reshape((1, 1, Seitenlaenge)).repeat(Seitenlaenge, axis = 0).repeat(Seitenlaenge, axis = 1)
kArray[:, :, :, 1] = np.arange(Seitenlaenge).reshape((1, Seitenlaenge, 1)).repeat(Seitenlaenge, axis = 0).repeat(Seitenlaenge, axis = 2)
kArray[:, :, :, 0] = np.arange(Seitenlaenge).reshape((Seitenlaenge, 1, 1)).repeat(Seitenlaenge, axis = 1).repeat(Seitenlaenge, axis = 2)
# Die Gitterpunkte waehlen die zu A und B gehoeren:

print kArray

Summe = 5 # Seitenlaenge des Dreiecks, das aus dem 1.Oktanten geschnitten wuerde, wenn der Wuerfel nicht kleiner waere
ObA = kArray.sum(axis=-1) == Summe-1 # 3D-boolean Array
ObB = kArray.sum(axis=-1) == Summe-2

print ObA

kA, kB = kArray[ObA], kArray[ObB] # Es bleiben 2D-Arrays: Listen von Koordina-
# tentripeln, in der Form (x, y, z)

print kA

如果你想看到蜂窝格子,那么事后再做:

import matplotlib.pyplot as plt

nx = np.array([-1, 1, 0])*2**-0.5
ny = np.array([-1, -1, 2])*6**-0.5
def Projektion(ListeTripel):
    return dot(ListeTripel, nx), dot(ListeTripel, ny)

xA, yA = Projektion(kA)
xB, yB = Projektion(kB)

plt.plot(xA.flatten(), yA.flatten(), 'o', c='r', ms=8, mew=0)
plt.plot(xB.flatten(), yB.flatten(), 'o', c='b', ms=8, mew=0)

plt.show()

1 个答案:

答案 0 :(得分:2)

Numpy对索引非常聪明。它将展平您的布尔数组,计算nnz,其中True的数量,分配形状(nnz, 3)的输出数组,然后逐项迭代展平的布尔数组,和你的扁平阵列跳跃3项,即3项步幅。无论布尔数组有True,它都会将数组的后3项复制到输出数组,然后继续迭代。

所有这一切都将在C中发生,所以它非常非常快,至少是Python标准。

顺便说一句,与您的问题有点无关,但请使用broadcasting

length = 4
indices = np.arange(length)
k_array = np.empty((length,) * 3 + (3,), dtype=np.intp)
k_array[..., 0] = indices
k_array[... ,1] = indices[:, None]
k_array[... ,2] = indices[:, None, None]