我尝试在类方法中使用@
。像这样
class Dataset:
@parse_func
def get_next_batch(self):
return self.generator.__next__()
和类似的解析函数:
def parse_func(load_batch):
def wrapper(**para):
batch_files_path, batch_masks_path, batch_label = load_batch(**para)
batch_images = []
batch_masks = []
for (file_path, mask_path) in zip(batch_files_path, batch_masks_path):
image = cv2.imread(file_path)
mask = cv2.imread(mask_path)
batch_images.append(image)
batch_masks.append(mask)
return np.asarray(batch_images, np.float32), np.asarray(batch_masks, np.uint8), batch_label
return wrapper
但是,当我打电话给dataset.get_next_batch()
时,它会按如下方式引发exception
。
回溯(最近通话最近): TypeError:wrapper()恰好接受0个参数(给定1个)
您知道为什么会引发此错误以及任何解决方案吗?非常感谢你!
答案 0 :(得分:1)
函数wrapper(**kwargs)
仅接受命名参数。但是,在实例方法中,self
作为 first positional 参数自动传递。由于您的方法不接受位置参数,因此会失败。
您可以编辑为wrapper(self, **kwargs)
,或更一般的wrapper(*args, **kwargs)
。但是,您使用它的方式尚不清楚这些参数是什么。
答案 1 :(得分:0)
只需更改
def parse_func(load_batch):
def wrapper(*para):
batch_files_path, batch_masks_path, batch_label = load_batch(*para)
batch_images = []
batch_masks = []
for (file_path, mask_path) in zip(batch_files_path, batch_masks_path):
image = cv2.imread(file_path)
mask = cv2.imread(mask_path)
batch_images.append(image)
batch_masks.append(mask)
return np.asarray(batch_images, np.float32), np.asarray(batch_masks, np.uint8), batch_label
return wrapper()
@
符号表示装饰器功能。在这里,它的意思是parse_func(get_next_batch)
。因此,如果包装器使用关键字params(**para
),则只想将一些参数传递给包装器,但实际上self
参数除外。因此,在这里我将参数替换为位置参数*para
。