我想根据列矩阵从矩阵中的每一行中选择一个元素。 因此,列矩阵包含要选择的索引。
(Pdb) num_samples
15000
(Pdb) probs.shape
(15000, 26)
(Pdb) y.shape
(15000, 1)
(Pdb) (probs[np.arange(num_samples),y]).shape
(15000, 15000)
(Pdb) # this should (15000,)
答案 0 :(得分:1)
integer array indexing可能会有所帮助。
假设你有这个numpy
数组:
myArray = numpy.array([[2, 3, 4],
[6, 7, 8],
[9, 1, 5]])
如果要选择的索引数组是
indices = numpy.array([2, 0, 1])
然后
rowSelector = numpy.arange(myArray.shape[0])
myArray[rowSelector, indices]
返回数组中所选元素的值:
array([4, 6, 1])
答案 1 :(得分:0)
您的y
变量可能是一个numpy矩阵。当y
是列表时,您的代码可以正常工作:
>>> num_samples = 3
>>> data = np.matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> y = [0, 1, 2]
>>> print((data[np.arange(num_samples), y]).shape)
(1, 3)
,结果可以很容易地重新排列成列矩阵。