仅在从其他文件导入后,Keras中的DataGenerator错误

时间:2018-07-10 07:29:13

标签: python tensorflow keras jupyter-notebook

我正在使用jupyter笔记本运行具有以下文件结构的python代码:

  

文件结构

Main_jupyter_file.ipynb
dataloader.py
     

Main_jupyter_file.ipynb

from dataloader import DataGenerator
training_generator = DataGenerator(...)
model = ...
model.compile(loss=...optimizer='adam')
model.fit_generator(generator=training_generator,\
use_multiprocessing=True,workers=8, epochs=2)
     

dataloader.py

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):

    def __len__(self):
    'Denotes the number of batches per epoch'
        return ...

    def __getitem__(self, index):
        return ...

运行model.fit_generator后,我收到警告和错误:

警告:

警告:tensorflow:使用带有use_multiprocessing=True的生成器和多个工作线程可能会复制您的数据。请考虑使用keras.utils.Sequence类。

错误:

steps_per_epoch=None仅对基于keras.utils.Sequence类的生成器有效。请指定steps_per_epoch或使用keras.utils.Sequence类。

使用tf.keras和本地Keras可获得相同的结果。

奇怪的发现

当我将dataGenerator类的代码从dataloader.py复制到jupyter笔记本文档中时,没有任何错误或警告,一切正常。

0 个答案:

没有答案