np.where IndexError异常

时间:2018-05-03 03:11:04

标签: python numpy

我有一个非常简单的代码如下:

import numpy as np
num_classes = 12
im_pred = np.random.randint(0, num_classes, (224, 244))
img = np.zeros((224, 224, 3))
print(im_pred.shape)
#(224, 244)
print(img.shape)
#(224, 224, 3)
for i in range(num_classes):
    img[np.where(im_pred==i), :] = [225, 0, 0]
  

追踪(最近的呼叫最后):
        文件"",第2行,在< module>中       IndexError:索引227超出轴0的范围,大小为224

x, y = np.where(im_pred==i)
print(np.max(x), np.max(y))
#223 243

为什么我得到IndexError?至于我对np.where的理解,返回的索引值应小于224

让我知道。我开始怀疑numpy安装是否有问题。

感谢。

2 个答案:

答案 0 :(得分:1)

No Numpy不是马车。看看你如何定义im_pred一秒钟,你正在绘制一个0到11之间的随机整数,对于一个大小为224乘244的数组。所以它抛出错误的原因是因为244的尺寸对你来说太大了变量img只有224乘224乘以3.我认为你可能意味着两者都有相同的第一维和第二维,比如

img = np.zeros((224,244,3)) 

答案 1 :(得分:1)

问题在于您制作了不同大小的imgimg_pred

im_pred.shape == (224, 244)

,而

img.shape == (224, 224, 3)

第二轴有不同的尺寸。

但是一旦你解决了这个问题,就可以进行简单的优化。这里不需要np.where。只需使用直接逻辑索引:

for i in range(num_classes):
    img[im_pred == i, 0] = 255

注意我也没有留下两个零,因为你在构造时用零初始化数组。