keras load_model第二次执行时引发错误

时间:2017-08-18 12:08:08

标签: python tensorflow neural-network keras conv-neural-network

我正在创建一个网站,有时它会调用keras神经网络。 所以我有一个看起来像这样的函数:

def network(campaign):
    from keras.models import load_model
    model = load_model("sunshade/neural_network/model.h5") #the line that fail the second time i call it

    #loading some data

    label = model.predict(images, batch_size = 128, verbose = 1)

    #some unrelated code...

当我第一次执行它时,此代码工作正常,但是当我尝试第二次运行时,它会因此错误而失败:

Exception in thread Thread-31:
Traceback (most recent call last):
  File "/usr/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 930, in _run
    allow_operation=False)
  File "/usr/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 2414, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/usr/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 2493, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("Placeholder_3:0", shape=(32,), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib64/python3.4/threading.py", line 920, in _bootstrap_inner
    self.run()
  File "/usr/lib64/python3.4/threading.py", line 868, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ec2-user/SpyNet/poc/sunshadeDetector/sunshade/models.py", line 46, in launch_network
    network(self)
  File "/home/ec2-user/SpyNet/poc/sunshadeDetector/sunshade/neural_network/network.py", line 27, in network
    model = load_model("sunshade/neural_network/model.h5")
  File "/usr/local/lib64/python3.4/site-packages/keras/models.py", line 236, in load_model
    topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)
  File "/usr/local/lib64/python3.4/site-packages/keras/engine/topology.py", line 3048, in load_weights_from_hdf5_group
    K.batch_set_value(weight_value_tuples)
  File "/usr/local/lib64/python3.4/site-packages/keras/backend/tensorflow_backend.py", line 2188, in batch_set_value
    get_session().run(assign_ops, feed_dict=feed_dict)
  File "/usr/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
  File "/usr/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 933, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder_3:0", shape=(32,), dtype=float32) is not an element of this graph.

顺便说一句,我使用django作为网站部分,但我不认为它是相关的。 必须有某种东西需要关闭或重新初始化...我试图使用tf.Session()和tf.reset_default_graph,但我仍然会遇到错误。 所以现在我每次想要使用这个功能时都要重启我的django服务器。

你知道吗?在最坏的情况下,我可以将模型设为单例,这样我就不必每次都重新加载...

2 个答案:

答案 0 :(得分:3)

您可以创建一个新的会话并将模型加载到其中。

from keras.models import load_model
import keras

def network(campaign):
    with keras.backend.get_session().graph.as_default():
        model = load_model("sunshade/neural_network/model.h5")
        label = model.predict(images, batch_size = 128, verbose = 1)

答案 1 :(得分:0)

当以不同的线程(即通常在Web服务中)加载图形时,我也面临此问题。

以下是我解决该问题的方法:

在加载或构建模型时,保存TensorFlow图:

graph = tf.get_default_graph()

在另一个线程中(或在异步事件处理程序中),我这样做:

global graph
with graph.as_default():
   # some predict or .....

我通过阅读以下问题来学习并解决了问题:

  1. https://github.com/Breta01/handwriting-ocr/issues/98
  2. https://github.com/keras-team/keras/issues/2397