我正在使用jupyter笔记本运行具有以下文件结构的python代码:
文件结构
Main_jupyter_file.ipynb dataloader.py
Main_jupyter_file.ipynb
from dataloader import DataGenerator training_generator = DataGenerator(...) model = ... model.compile(loss=...optimizer='adam') model.fit_generator(generator=training_generator,\ use_multiprocessing=True,workers=8, epochs=2)
dataloader.py
class DataGenerator(tf.keras.utils.Sequence): def __init__(self, ...): def __len__(self): 'Denotes the number of batches per epoch' return ... def __getitem__(self, index): return ...
运行model.fit_generator
后,我收到警告和错误:
警告:tensorflow:使用带有use_multiprocessing=True
的生成器和多个工作线程可能会复制您的数据。请考虑使用keras.utils.Sequence类。
steps_per_epoch=None
仅对基于keras.utils.Sequence
类的生成器有效。请指定steps_per_epoch
或使用keras.utils.Sequence
类。
使用tf.keras和本地Keras可获得相同的结果。
当我将dataGenerator类的代码从dataloader.py复制到jupyter笔记本文档中时,没有任何错误或警告,一切正常。