Keras:从自定义生成器获取单个批处理

时间:2019-01-09 19:23:25

标签: python tensorflow keras

我使用https://github.com/keras-team/keras/issues/3386#issuecomment-237555199这里的代码作为模板创建了一个自定义生成器。我想使用.flow而不是使用.flow_from_dataframe。因此,我的生成器如下所示:

def createGenerator( X, I):

    while True:

        # suffled indices    
        idx = np.random.permutation( X.shape[0])
        # create image generator
        datagen=ImageDataGenerator(preprocessing_function=preprocess_input,
                              horizontal_flip = True,
                              rotation_range=20)


        batches = datagen.flow_from_dataframe(shuffle=False,dataframe=X[["url",target]], directory=bilder_resized_dir, x_col="url", y_col=target,
                                            has_ext=False, class_mode="categorical", color_mode=color_mode, target_size=(IMG_SIZE,IMG_SIZE), batch_size=batch_size,classes=class_list)
        idx0 = 0
        for batch in batches:
            idx1 = idx0 + batch[0].shape[0]

            yield [batch[0], I[ idx[ idx0:idx1 ] ]], batch[1]

            idx0 = idx1
            if idx1 >= X.shape[0]:
                break

现在,我想通过检查单个批次的输出来测试此生成器。我尝试像这样使用next()

combo_gen = createGenerator(X1_dataframe ,X2_aux_input)
x,y = next(combo_gen)

但是这给了我以下错误:

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-274-4a67da5d38d9> in <module>
----> 1 x,y = next(combo_gen)

StopIteration: 

问题:

如何从此生成器中获取单个批次?

1 个答案:

答案 0 :(得分:0)

您可以调用自定义生成器的 __getitem__ 方法,例如以 index =0 获取第一批。