我想在指定的轴中收集指定索引的元素,如下所示。
x = [[1,2,3], [4,5,6]]
index = [[2,1], [0, 1]]
x[:, index] = [[3, 2], [4, 5]]
这基本上是在pytorch中的聚集操作,但是如你所知,这在numpy中是不可能实现的。我想知道在numpy中是否有这样的“聚集”操作?
答案 0 :(得分:3)
我前一段时间写这是为了在Numpy中复制PyTorch的{{1}}。在这种情况下,gather
是您的self
x
这些是测试用例:
def gather(self, dim, index):
"""
Gathers values along an axis specified by ``dim``.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters
----------
dim:
The axis along which to index
index:
A tensor of indices of elements to gather
Returns
-------
Output Tensor
"""
idx_xsection_shape = index.shape[:dim] + \
index.shape[dim + 1:]
self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
if idx_xsection_shape != self_xsection_shape:
raise ValueError("Except for dimension " + str(dim) +
", all dimensions of index and self should be the same size")
if index.dtype != np.dtype('int_'):
raise TypeError("The values of index must be integers")
data_swaped = np.swapaxes(self, 0, dim)
index_swaped = np.swapaxes(index, 0, dim)
gathered = np.choose(index_swaped, data_swaped)
return np.swapaxes(gathered, 0, dim)
答案 1 :(得分:1)
>>> x = np.array([[1,2,3], [4,5,6]])
>>> index = np.array([[2,1], [0, 1]])
>>> x_axis_index=np.tile(np.arange(len(x)), (index.shape[1],1)).transpose()
>>> print x_axis_index
[[0 0]
[1 1]]
>>> print x[x_axis_index,index]
[[3 2]
[4 5]]
答案 2 :(得分:1)
numpy.take_along_axis是我所需要的,根据索引获取元素。可以像PyTorch中的collect方法一样使用。
这是手册中的示例:
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
[0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
[60]])
答案 3 :(得分:0)
使用numpy.take()函数,该函数具有PyTorch的大多数collect函数功能。