在一周中的奇数天,我几乎理解numpy中的多维索引。 Numpy有一个函数'take',它似乎做了我想要的,但有额外的好处,我可以控制如果索引超出范围会发生什么 具体来说,我有一个三维数组要求作为查找表
lut = np.ones([13,13,13],np.bool)
和一个2x2的3长矢量数组,作为表中的索引
arr = np.arange(12).reshape([2,2,3]) % 13
IIUC,如果我要编写lut[arr]
,则arr
被视为2x2x3数字数组,当这些数字用作lut
的索引时,它们每个都返回一个13x13数组。这解释了为什么lut[arr].shape is (2, 2, 3, 13, 13)
。
我可以通过编写
来实现我的目标lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] #(is there a better way to write this?)
现在这三个术语就好像它们已被压缩以生成一个2x2元组数组而lut[<tuple>]
从lut
生成一个元素。最终结果是来自lut
的2x2条目数组,正如我想要的那样。
我已阅读'take'功能的文档......
此功能与“花式”索引相同 (使用数组索引数组);但是,它可以更容易 如果您需要沿给定轴的元素,请使用。
和
轴:int,可选
用于选择值的轴。
也许天真,我认为设置axis=2
我会得到三个值作为3元组来执行查找,但实际上
np.take(lut,arr).shape = (2, 2, 3)
np.take(lut,arr,axis=0).shape = (2, 2, 3, 13, 13)
np.take(lut,arr,axis=1).shape = (13, 2, 2, 3, 13)
np.take(lut,arr,axis=2).shape = (13, 13, 2, 2, 3)
所以我很清楚我不明白发生了什么。任何人都可以告诉我如何实现我想要的目标吗?
答案 0 :(得分:2)
我们可以计算线性指数,然后使用np.take
-
np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
如果您对替代方案持开放态度,我们可以将indices数组重新整形为2D
,转换为元组,使用它索引到数据数组中,给我们1D
,这可以重新转换为2D
-
lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
示例运行 -
In [49]: lut = np.random.randint(11,99,(13,13,13))
In [50]: arr = np.arange(12).reshape([2,2,3])
In [51]: lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] # Original approach
Out[51]:
array([[41, 21],
[94, 22]])
In [52]: np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
Out[52]:
array([[41, 21],
[94, 22]])
In [53]: lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
Out[53]:
array([[41, 21],
[94, 22]])
我们可以避免np.take
方法的双重转置,就像这样 -
In [55]: np.take(lut, np.ravel_multi_index(arr.transpose(2,0,1), lut.shape))
Out[55]:
array([[41, 21],
[94, 22]])
推广到通用维度的多维数组
这可以推广到通用编号的ndarray。昏暗的,像这样 -
np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
tuple-based
方法应该无任何变化。
这是一个样本运行 -
In [95]: lut = np.random.randint(11,99,(13,13,13,13))
In [96]: arr = np.random.randint(0,13,(2,3,4,4))
In [97]: lut[ arr[:,:,:,0] , arr[:,:,:,1],arr[:,:,:,2],arr[:,:,:,3] ]
Out[97]:
array([[[95, 11, 40, 75],
[38, 82, 11, 38],
[30, 53, 69, 21]],
[[61, 74, 33, 94],
[90, 35, 89, 72],
[52, 64, 85, 22]]])
In [98]: np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
Out[98]:
array([[[95, 11, 40, 75],
[38, 82, 11, 38],
[30, 53, 69, 21]],
[[61, 74, 33, 94],
[90, 35, 89, 72],
[52, 64, 85, 22]]])
答案 1 :(得分:0)
最初的问题是尝试在表中进行查找但是 一些索引超出范围,我想控制 发生这种情况时的行为。
import numpy as np
lut = np.ones((5,7,11),np.int) # a 3-dimensional lookup table
print("lut.shape = ",lut.shape ) # (5,7,11)
# valid points are in the interior with value 99,
# invalid points are on the faces with value 0
lut[:,:,:] = 0
lut[1:-1,1:-1,1:-1] = 99
# set up an array of indexes with many of them too large or too small
start = -35
arr = np.arange(start,2*11*3+start,1).reshape(2,11,3)
# This solution has the advantage that I can understand what is going on
# and so I can amend it if I need to
# split arr into tuples along axis=2
arrchannels = arr[:,:,0],arr[:,:,1],arr[:,:,2]
# convert into a flat array but clip the values
ravelledarr = np.ravel_multi_index(arrchannels, lut.shape, mode='clip')
# and now turn back into a list of numpy arrays
# (not an array of the original shape )
clippedarr = np.unravel_index( ravelledarr, lut.shape)
print(clippedarr[0].shape,"*",len(clippedarr)) # produces (2, 11) * 3
# and now I can do the lookup with the indexes clipped to fit
print(lut[clippedarr])
# these are more succinct but opaque ways of doing the same
# due to @Divakar and @hjpauli respectively
print( np.take(lut, np.ravel_multi_index(arr.T, lut.shape, mode='clip')).T )
print( lut.flat[np.ravel_multi_index(arr.T, lut.shape, mode='clip')].T )
实际的应用是我有一个rgb图像,其中包含一些带有一些标记的纹理木材,我已经确定了它的一块。我想获取此补丁中的像素集,并在整个图像中标记与其中一个匹配的所有点。 256x256x256存在表太大,所以我在补丁的像素上运行聚类算法,并为每个聚类设置存在表(补丁中的颜色通过rgb-或hsv-space形成细长线程,因此聚类周围的框很小)。
我使存在表略大于需要,并用False填充每个面。
一旦我设置了这些小存在表,我现在可以通过查找表中的每个像素来测试图像的其余部分以匹配补丁,并使用剪辑来制作通常不会映射到表中的像素实际上映射到表的一个面(并获得值'False')
答案 2 :(得分:0)
我没有尝试三维。但是在 2维中,我使用 numpy.take 获得了我想要的结果:
np.take(np.take(T,ix,axis=0), iy,axis=1 )
也许您可以将其扩展为三维。
作为示例,我可以使用两个1-dim数组作为索引ix,而iy是2维模板用于离散拉普拉斯方程,
ΔT = T[ix-1,iy] + T[ix+1, iy] + T[ix,iy-1] + T[ix,iy+1] - 4*T[ix,iy]
介绍更精简的写作:
def q(Φ,kx,ky):
return np.take(np.take(Φ,kx,axis=0), ky,axis=1 )
然后我可以使用numpy.take运行以下 python代码:
nx = 6; ny= 10
T = np.arange(nx*ny).reshape(nx, ny)
ix = np.linspace(1,nx-2,nx-2,dtype=int)
iy = np.linspace(1,ny-2,ny-2,dtype=int)
ΔT = q(T,ix-1,iy) + q(T,ix+1,iy) + q(T,ix,iy-1) + q(T,ix,iy+1) - 4.0 * q(T,ix,iy)