自定义图像数据生成器:在 on_epoch_end 中打乱(shuffle)数据时出错

时间:2019-06-24 16:18:18

标签: python-3.x tensorflow keras generator

我正在尝试编写一个自定义图像数据生成器,该类继承自 keras.utils.Sequence,但在 on_epoch_end 中出现错误,提示没有足够的值可以解包(not enough values to unpack)。

class CityscapesGenerator(Sequence):
    def __init__(self, folder='/cityscapes_reordered', mode='train', n_classes=20, batch_size=1, resize_shape=(2048, 1024),
                 crop_shape=(2048, 1024), horizontal_flip=True, vertical_flip=False, brightness=0.1, rotation=5.0,
                 zoom=0.1):
        """Collect the image/label/edge file paths and store the augmentation settings.

        Files are found with ``glob`` under ``<folder>/img/<mode>/png/*``,
        ``<folder>/label/<mode>/png/*`` and ``<folder>/edge/<mode>/png/*``.
        Each list is sorted, so samples are paired by their position in the
        sorted lists, not by matching file names.

        Args:
            folder: Dataset root directory.
            mode: Sub-directory name used in all three glob patterns. It is
                also stored as ``self.mode``.
            n_classes: Number of classes passed to ``to_categorical`` in
                ``__getitem__``.
            batch_size: Samples per batch (used by ``__len__``/``__getitem__``).
            resize_shape, crop_shape, horizontal_flip, vertical_flip,
            brightness, rotation, zoom: Resize/augmentation settings stored
                as attributes. How they are applied is elided in this listing.

        NOTE(review): nothing here checks that the three globs matched any
        files or that they matched the same number of files. ``zip`` in
        ``__getitem__`` and ``on_epoch_end`` silently truncates to the
        shortest list. For example, an ``edge`` directory that matches nothing
        makes the zipped list in ``on_epoch_end`` empty. That produces the
        "not enough values to unpack (expected 3, got 0)" error. Consider
        validating the three lengths here.

        NOTE(review): ``__getitem__`` only augments when ``self.mode ==
        'training'``, but the default (and directory name) here is
        ``'train'``. Confirm which value is intended.
        """

        self.image_path_list = sorted(glob.glob(os.path.join(folder, 'img',mode, 'png/*')))
        self.label_path_list = sorted(glob.glob(os.path.join(folder, 'label',mode, 'png/*')))
        #edge
        self.edge_path_list = sorted(glob.glob(os.path.join(folder, 'edge',mode, 'png/*')))
        self.mode = mode
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.resize_shape = resize_shape
        self.crop_shape = crop_shape
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.brightness = brightness
        self.rotation = rotation
        self.zoom = zoom
        # NOTE(review): the rest of __init__ is elided in this listing. The
        # X1 / Y1 / Y2 / Y3 buffers written by __getitem__ are not created in
        # any visible code; presumably they are allocated here. Confirm.
        .
        .

    def __len__(self):
        return len(self.image_path_list) // self.batch_size

    def __getitem__(self, i):
        """Fill and return batch ``i`` built from the sorted path lists.

        For each sample in the slice ``[i * batch_size, (i + 1) * batch_size)``:
        - the image is read with ``cv2.imread(path, 1)`` (colour);
        - a single-channel map is stacked onto it as a 4th channel;
        - the result is optionally resized/augmented (that code is elided in
          this listing) and written to ``self.X1[n]``.
        The label is one-hot encoded with ``self.n_classes`` classes at 1/4,
        1/8 and 1/16 of its size into ``self.Y1[n]``, ``self.Y2[n]`` and
        ``self.Y3[n]``. The elided code may also transform ``label`` first.

        Returns:
            ``(self.X1, [self.Y1, self.Y2, self.Y3])``. These are the same
            buffer objects on every call, so a later call overwrites the
            previous batch's contents.

        NOTE(review): ``zip`` truncates to the shortest of the three slices.
        If the label or edge list is shorter than the image list, fewer than
        ``batch_size`` slots are refreshed. The remaining slots keep stale
        data from an earlier batch.
        """
        for n, (image_path, label_path,edge_path) in enumerate(
                zip(self.image_path_list[i * self.batch_size:(i + 1) * self.batch_size],
                    self.label_path_list[i * self.batch_size:(i + 1) * self.batch_size],
                    self.edge_path_list[i * self.batch_size:(i + 1) * self.batch_size])):

            image = cv2.imread(image_path, 1)
            label = cv2.imread(label_path, 0)
            # NOTE(review): this reads label_path, not edge_path, so the 4th
            # channel is the label map and edge_path is never used. It was
            # almost certainly meant to be cv2.imread(edge_path, 0).
            edge = cv2.imread(label_path, 0)
            # Fixed 1024 x 2048 canvas (rows x cols) with 4 channels. This
            # assumes every image is exactly 1024 x 2048 with 3 channels;
            # otherwise the slice assignments below fail.
            combine = np.zeros((1024, 2048, 4))
            combine[:, :, :3] = image
            combine[:, :, -1] = edge
            image=combine
            if self.resize_shape:
                ....
            # Do augmentation (only if training)
            # NOTE(review): compares against 'training', but the default mode
            # in __init__ is 'train' (also used as the directory name), so
            # augmentation is skipped with the defaults. Confirm the intent.
            if self.mode == 'training':
                if self.horizontal_flip and random.randint(0, 1):
                    ....
                if self.vertical_flip and random.randint(0, 1):
                    .....
                if self.brightness:
                    .....
                    if random.randint(0, 1):
                        ....
                if self.rotation:
                    .....
                else:
                    .....
                if self.zoom:
                    .....
                else:
                    .....
                if self.rotation or self.zoom:
                    .....
                if self.crop_shape:
                    .....

            self.X1[n] = image
            #edge
            # self.X2[n] = edge

            # Multi-scale targets: the label downsampled to 1/4, 1/8 and 1/16,
            # then one-hot encoded and reshaped to (rows, cols, classes).
            # NOTE(review): cv2.resize defaults to bilinear interpolation,
            # which can blend neighbouring class ids at region borders. Label
            # maps are usually resized with interpolation=cv2.INTER_NEAREST.
            # Confirm before relying on these targets.
            self.Y1[n] = to_categorical(cv2.resize
                                        (label,(label.shape[1] // 4, label.shape[0] // 4)),
                                        num_classes=self.n_classes).reshape((label.shape[0] // 4, label.shape[1] // 4, -1))
            self.Y2[n] = to_categorical(cv2.resize(label, (label.shape[1] // 8, label.shape[0] // 8)),
                                        num_classes=self.n_classes).reshape((label.shape[0] // 8, label.shape[1] // 8, -1))
            self.Y3[n] = to_categorical(cv2.resize(label, (label.shape[1] // 16, label.shape[0] // 16)),
                                        num_classes=self.n_classes).reshape((label.shape[0] // 16, label.shape[1] // 16, -1))

            # edge
            # NOTE(review): this disabled Y4 resizes to 1/4 but reshapes to
            # 1/16, and it encodes `label`, not `edge`. Fix both before
            # re-enabling it.
            # self.Y4[n] = to_categorical(cv2.resize(label, (label.shape[1] // 4, label.shape[0] // 4)),
            #                             self.n_classes).reshape((label.shape[0] // 16, label.shape[1] // 16, -1))

        return self.X1, [self.Y1, self.Y2, self.Y3]

    def on_epoch_end(self):
        # Shuffle dataset for next epoch
        c = list(zip(self.image_path_list, self.label_path_list,self.edge_path_list))
        random.shuffle(c)
        self.image_path_list, self.label_path_list,self.edge_path_list = zip(*c)

        # Fix memory leak (tensorflow.python.keras bug)
        gc.collect()

这是我得到的错误:

Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/stu953839035/.local/lib/python3.6/site-packages/tensorflow/python/keras/utils/data_utils.py", line 634, in _run
    self.sequence.on_epoch_end()
  File "/home/stu953839035/Desktop/Keras-ICNet/utils_edited.py", line 143, in on_epoch_end
    self.image_path_list, self.label_path_list,self.edge_path_list = zip(*c)
ValueError: not enough values to unpack (expected 3, got 0)

我已经反复检查了代码很多次,甚至在主程序之外单独模拟运行了 on_epoch_end,结果都正常!

1 个答案:

答案 0 :(得分:0)

我建议参阅 Stack Overflow 上的问题“Transpose/Unzip Function (inverse of zip)?”,其中说明了 `zip(*arg)` 无法产生预期结果的几种情况。

第一次调用 `on_epoch_end` 时,其中的 `zip(*c)` 会把 `self.image_path_list` 等路径列表从列表转换为元组,这可能会在以后造成麻烦。另外要注意:`c = list(zip(...))` 会按三个列表中最短的那个截断,因此只要其中任何一个列表为空(例如某个 glob 模式没有匹配到任何文件),`c` 就是空列表,`zip(*c)` 不产生任何值,于是出现 “expected 3, got 0” 的解包错误。请先确认三个路径列表都非空且长度相同。