如何在numpy中索引更高维度的项目?

时间:2017-11-06 20:26:45

标签: python numpy

我在这个网站上做了一些搜索,但似乎没有类似的问题(或者我的描述太糟糕了,无法搜索)。

我正面临一个问题,我需要从numpy中的多维ndarray中获取一些ndarray。

假设我有

W = np.random.randn(2,2,3,8)

表示CNN中卷积层中的8个2x2x3滤波器。

我想访问第一个过滤器,这是W中的第一个2x2x3。

我试过

 print(W.shape)
 print(W[:,:,:,:c].shape)
 print(W[:,:,:,:c])
 print(W[:,:,:,:c].flatten())

其中c的范围为0 - 7.返回的结果始终为

(2, 2, 3, 8)
(2, 2, 3, 0)
[]
[]

但我希望看到上面索引的2x2x3过滤器结果..

来自以上4行代码......

繁殖:

W = np.random.randn(2,2,3,8)
for c in range(0, 8):
#   print(W.shape)
    print(W[:,:,:,:c].shape)
#   print(W[:,:,:,:c])
#   print(W[:,:,:,:c].flatten())

结果是:

(2, 2, 3, 0)
(2, 2, 3, 1)
(2, 2, 3, 2)
(2, 2, 3, 3)
(2, 2, 3, 4)
(2, 2, 3, 5)
(2, 2, 3, 6)
(2, 2, 3, 7)

我实际上期望8(2,2,3)。 请帮忙!

1 个答案:

答案 0 :(得分:1)

您可以通过切片W

来访问过滤器
import numpy as np

W = np.random.randn(2,2,3,8)
for c in range(0, 8):
    print(W[:,:,:,c].shape)
    print(W[:,:,:,c])

基本上,W[:,:,:,0]是您的第一个过滤器,依此类推。