RuntimeError:会话图为空。在调用run()之前向图形添加操作

时间:2019-01-28 15:09:10

标签: python tensorflow

我正在用keras训练神经网络,并且由于我的数据集非常大,因此我正在使用fit_generator将数据馈送到网络。 作为fit_generator的第一个参数,我必须提供一个生成器,为我的模型生成数据补丁。 我使用tf.data.Dataset来创建数据集并使用make_one_shot_iterator并调用get_next方法来馈送网络。 这是代码

def generator():
    dataset_iterator = DatasetGenerator(...)  # defined class to returns a tf iterator
    with tf.Session() as sess:
        next_batch = dataset_iterator.get_next()

        while True:
            img, label = sess.run(next_batch)
            # some process on label
            yield img, label


# down in the code for training:
model.fit_generator(generator=generator(), ...)

这很好用。 当我尝试将dataset_iterator作为参数发送给generator方法时,问题就开始了,

def generator(dataset_iterator):
    with tf.Session() as sess:
        next_batch = dataset_iterator.get_next()

        while True:
            img, label = sess.run(next_batch)
            # some process on label
            yield img, label


# down in the code for training:
dataset_iterator = DatasetGenerator(...)
model.fit_generator(generator=generator(dataset_iterator), ...)

现在,出现以下错误:

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

2 个答案:

答案 0 :(得分:0)

我找到了一种处理方法。 我发现,在tf.get_default_graph()方法和main方法中(即在调用generator之前)打印model.fit_generator会返回不同的图形。

为什么?我不知道!

无论如何,我通过将默认图形作为函数的另一个参数发送并将其引入tf.Session()来解决了该问题。像这样:

def generator(dataset_iterator, default_graph):
    with tf.Session(graph=default_graph) as sess:
        next_batch = dataset_iterator.get_next()

        while True:
            img, label = sess.run(next_batch)
            # some process on label
            yield img, label


# down in the code for training:
dataset_iterator = DatasetGenerator(...)
default_graph = tf.get_default_graph()
model.fit_generator(generator=generator(dataset_iterator, default_graph), ...)

我实际上不知道这是否是解决问题的最优雅的方法。非常感谢进一步的改进:)

答案 1 :(得分:-1)

它渴望执行,如果要创建一个空会话,请禁用它。