keras.utils.Sequence-对象不是迭代器

时间:2019-04-28 10:35:01

标签: python keras

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.epoch = 0
        self.batch = 0
        self.batch_size = batch_size
        self.per = 4

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[int(np.ceil(self.x.shape[0]*(self.per/100)))]
        batch_y = self.y[int(np.ceil(self.x.shape[0]*(self.per/100)))]
        return np.array(batch_x), np.array(batch_y)
        return (batch_x, batch_y)

    def on_batch_end(self):
        if self.epoch % 100 == 0:
            self.per = self.per*1.9
        self.epoch += 1

train_datagen = CIFAR10Sequence(new_x_sort, new_x_sort, 100)
test_datagen = CIFAR10Sequence(cifar100_dataset.x_test,
                               cifar100_dataset.x_test, 100)

model.fit_generator(generator=train_datagen, steps_per_epoch=len(new_x_sort)//100, epochs=20)

但是我得到: TypeError: 'CIFAR10Sequence' object is not an iterator

2 个答案:

答案 0 :(得分:0)

您需要string_split()中的__iter__()函数。 像

CIFAR10Sequence

答案 1 :(得分:0)

在您调用 fit_generator 之前,似乎必须在另一个上下文中引用了 Sequence 对象。该调用不为生成器使用关键字 arg,因此如果您确实到达该调用,则会收到关键字 arg 错误。 Sequence 对象确实有 __iter__() 所以它是可迭代的,但它没有 __next__() 所以它不是一个迭代器,如果它被这样引用,它会抛出那个错误。 __iter__() 是 keras fit 所需的全部。