我遇到了一个非常奇怪的问题。
请参阅以下代码。在课堂上有一个奇怪的属性'与所有其他属性一样定义。 (由
标记的代码的问题区域但是当我训练在其中一个函数中访问它时,它会引发一个错误,说该属性不存在。
但是,如果删除在函数中访问该属性的行,然后创建一个对象,则可以访问该对象的属性...
请你看看代码,看看出了什么问题,我在这里结束了......:'(
class img_iter(DataIter):
def __init__(self, dir_train,
dir_label,
data_name = "data",
label_name = "softmax_label",
last_batch_handle = 'pad',
batch_size = 1,
#===================================================
#===================================================
color_mean = (117, 117, 117),
#===================================================
#===================================================
cut_off_size = None):
super().__init__()
# directories as attributes
self.dir_train = dir_train
self.dir_label = dir_label
# names
self.data_name = data_name
self.label_name = label_name
# read data and label files into list
self.img_data_lst = [s for s in os.listdir(dir_train) if '.jpg' in s]
self.img_data_iter = iter(self.img_data_lst)
if self.dir_label is not None:
self.img_label_lst = [s for s in os.listdir(dir_label) if '.gif' in s]
# number of data files
self.num_data = len(self.img_data_lst)
# read data when initialising
self.data, self.label = self._read()
# size limiter
self.cut_off_size = cut_off_size
# batch size
self.batch_size = batch_size
# data cursor
self.cursor = -batch_size
#===================================================
#===================================================
self.weird = np.array(color_mean)
#===================================================
#===================================================
self.last_batch_handle = last_batch_handle
def _read(self):
"""get two list, each list contains two elements: name and nd.array
value"""
data = {}
label = {}
data[self.data_name], label[self.label_name] = self._read_img()
return list(data.items()), list(label.items())
def _read_img(self):
# get next data file from the file name iterator
img_name = self.img_data_iter.__next__()
# open image file
img = Image.open(os.path.join(self.dir_train, img_name))
# find corresponding label image and open [s for s in self.img_label_lst if img_name in s][0]
label_name = img_name.split('.')[0] + '_mask.gif'
label = Image.open(os.path.join(self.dir_label, label_name))
# check image file size match
assert img.size == label.size
# convert into numpy array and manipulate, resulting 3d array: height, width, color
img = np.array(img, dtype = np.float32)
#===================================================
#===================================================
#img = img - self.weird.reshape(1,1,3)
test = self.weird
#===================================================
#===================================================
img = np.swapaxes(img, 0, 2)
# (c, h, w)
img = np.swapaxes(img, 1, 2)
# (1, c, h, w)
img = np.expand_dims(img, axis=0)
# resulting 2d array: height, width
label = np.array(label)
# (h, w)
label = np.array(label)
# (1, h, w)
label = np.expand_dims(label, axis=0)
return (img, label)
def reset(self):
self.cursor = -1
self.img_data_iter = iter(self.img_data_lst)
def iter_next(self):
self.cursor += 1
if (self.cursor < self.num_data - 1):
return True
else:
return False
def next(self):
if self.iter_next():
'''
try:
self.data, self.label = self._read()
return {self.data_name : self.data[0][1],
self.label_name : self.label[0][1]}
except:
raise
'''
self.data, self.label = self._read()
return DataBatch(data = self.data, label = self.label, \
pad=self.getpad(), index=None)
else:
raise StopIteration
@property
def provide_data(self):
"""The name and shape of data provided by this iterator"""
return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.data]
@property
def provide_label(self):
"""The name and shape of label provided by this iterator"""
return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label]
def get_batch_size(self):
return 1
截图:
调用函数中的class属性。
删除对class属性的调用,创建一个对象并直接访问相同的attribte。
答案 0 :(得分:0)
您在定义之前尝试访问self.weird
。
self.data, self.label = self._read()
# self._read() calls self. _read_img()
# which tries to access self.weird
# => AttributeError...
#...
# move this line above `self.data...`,
# and all should be well!
self.weird = np.array(color_mean)
希望这有帮助!