我正在用keras训练神经网络,并且由于我的数据集非常大,因此我正在使用fit_generator
将数据馈送到网络。
作为fit_generator
的第一个参数,我必须提供一个生成器,为我的模型生成数据补丁。
我使用tf.data.Dataset
来创建数据集并使用make_one_shot_iterator
并调用get_next
方法来馈送网络。
这是代码
def generator():
dataset_iterator = DatasetGenerator(...) # defined class to returns a tf iterator
with tf.Session() as sess:
next_batch = dataset_iterator.get_next()
while True:
img, label = sess.run(next_batch)
# some process on label
yield img, label
# down in the code for training:
model.fit_generator(generator=generator(), ...)
这很好用。
当我尝试将dataset_iterator
作为参数发送给generator
方法时,问题就开始了,
def generator(dataset_iterator):
with tf.Session() as sess:
next_batch = dataset_iterator.get_next()
while True:
img, label = sess.run(next_batch)
# some process on label
yield img, label
# down in the code for training:
dataset_iterator = DatasetGenerator(...)
model.fit_generator(generator=generator(dataset_iterator), ...)
现在,出现以下错误:
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
答案 0 :(得分:0)
我找到了一种处理方法。
我发现,在tf.get_default_graph()
方法和main方法中(即在调用generator
之前)打印model.fit_generator
会返回不同的图形。
为什么?我不知道!
无论如何,我通过将默认图形作为函数的另一个参数发送并将其引入tf.Session()
来解决了该问题。像这样:
def generator(dataset_iterator, default_graph):
with tf.Session(graph=default_graph) as sess:
next_batch = dataset_iterator.get_next()
while True:
img, label = sess.run(next_batch)
# some process on label
yield img, label
# down in the code for training:
dataset_iterator = DatasetGenerator(...)
default_graph = tf.get_default_graph()
model.fit_generator(generator=generator(dataset_iterator, default_graph), ...)
我实际上不知道这是否是解决问题的最优雅的方法。非常感谢进一步的改进:)
答案 1 :(得分:-1)
它渴望执行,如果要创建一个空会话,请禁用它。