我目前使用https://github.com/0bserver07/Keras-SegNet-Basic中的keras segnet,运行它没有问题。但是,当我更改为自己的数据集时,出现索引错误问题。我在python和人工智能领域还很新,所以我希望任何人都可以帮助我。我搜索了其他类似的问题和解决方案,但仍然没有得到。
这是错误:
Traceback (most recent call last):
File "C:/Users/ZC/PycharmProjects/segnet(crack)/model.py", line 35, in <module>
train_data, train_label = load_data("train2")
File "C:/Users/ZC/PycharmProjects/segnet(crack)/model.py", line 30, in load_data
label.append(one_hot_it(cv2.imread(txt[i][1][:-1])[:,:,0]))
File "C:\Users\ZC\PycharmProjects\segnet(crack)\helper.py", line 29, in one_hot_it
x[i,j,labels[i][j]]=1
IndexError: index 255 is out of bounds for axis 2 with size 12
这是代码(model.py):
from __future__ import absolute_import
from __future__ import print_function
import cv2
import numpy as np
import itertools
from helper import *
import os
# Copy the data to this dir here in the SegNet project /CamVid from here:
# https://github.com/alexgkendall/SegNet-Tutorial
DataPath = 'C:/Keras-SegNet/SegNet/'
data_shape = 360*480
def load_data(mode):
data = []
label = []
with open(DataPath + mode +'.txt') as f:
txt = f.readlines()
txt = [line.split(' ') for line in txt]
for i in range(len(txt)):
print(txt[i][0])
print(txt[i][1][:-1])
img=cv2.imread( txt[i][1][:-1])
cv2.imshow('image',img)
data.append(np.rollaxis(normalized(cv2.imread( txt[i][0])),2))
label.append(one_hot_it(cv2.imread(txt[i][1][:-1])[:,:,0]))
print('.',end='')
return np.array(data), np.array(label)
train_data, train_label = load_data("train2")
train_label = np.reshape(train_label,(300,data_shape,2))
test_data, test_label = load_data("test")
test_label = np.reshape(test_label,(233,data_shape,2))
np.save("train_data", train_data)
np.save("train_label", train_label)
np.save("test_data", test_data)
np.save("test_label", test_label)
这是helper.py文件代码:
from __future__ import absolute_import
from __future__ import print_function
import cv2
import numpy as np
import itertools
from helper import *
import os
def normalized(rgb):
#return rgb/255.0
norm=np.zeros((rgb.shape[0], rgb.shape[1], 3),np.float32)
b=rgb[:,:,0]
g=rgb[:,:,1]
r=rgb[:,:,2]
norm[:,:,0]=cv2.equalizeHist(b)
norm[:,:,1]=cv2.equalizeHist(g)
norm[:,:,2]=cv2.equalizeHist(r)
return norm
def one_hot_it(labels):
x = np.zeros([360,480,12])
for i in range(360):
for j in range(480):
x[i,j,labels[i][j]]=1
return x
谁能解释一下该函数对helper.py的作用,特别是这行x [i,j,labels [i] [j]] = 1。
答案 0 :(得分:0)
在这里,您可以在创建x
时看到12
数组在第3轴上的大小为np.zeros([360,480,12])
。因此,每当访问第3轴上的某个内容时,都需要确保该内容介于0
和11
之间。
在这种情况下,labels[i][j]
在某个时刻的值为255
,因此它不起作用。这是因为您在图像上使用one_hot_it
而不是标签(one_hot_it(cv2.imread(txt[i][1][:-1])[:,:,0])
)。