我有一个非常简单的代码如下:
import numpy as np
num_classes = 12
im_pred = np.random.randint(0, num_classes, (224, 244))
img = np.zeros((224, 224, 3))
print(im_pred.shape)
#(224, 244)
print(img.shape)
#(224, 224, 3)
for i in range(num_classes):
img[np.where(im_pred==i), :] = [225, 0, 0]
追踪(最近的呼叫最后):
文件"",第2行,在< module>中 IndexError:索引227超出轴0的范围,大小为224
x, y = np.where(im_pred==i)
print(np.max(x), np.max(y))
#223 243
为什么我得到IndexError
?至于我对np.where
的理解,返回的索引值应小于224
。
让我知道。我开始怀疑numpy
安装是否有问题。
感谢。
答案 0 :(得分:1)
No Numpy不是马车。看看你如何定义im_pred一秒钟,你正在绘制一个0到11之间的随机整数,对于一个大小为224乘244的数组。所以它抛出错误的原因是因为244的尺寸对你来说太大了变量img只有224乘224乘以3.我认为你可能意味着两者都有相同的第一维和第二维,比如
img = np.zeros((224,244,3))
答案 1 :(得分:1)
问题在于您制作了不同大小的img
和img_pred
:
im_pred.shape == (224, 244)
,而
img.shape == (224, 224, 3)
第二轴有不同的尺寸。
但是一旦你解决了这个问题,就可以进行简单的优化。这里不需要np.where
。只需使用直接逻辑索引:
for i in range(num_classes):
img[im_pred == i, 0] = 255
注意我也没有留下两个零,因为你在构造时用零初始化数组。