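I'm trying to integrate some TensorFlow code that was handed to me into a Django application.
If I start the standard Django development server with python manage.py runserver, the first prediction (after the server first starts) works, and every subsequent prediction raises the exception shown below.
After some Googling, I found that if I run python manage.py runserver --nothreading instead, I no longer get the exception. However, I've never had to do this before, and I'm worried about its implications, especially since I won't have this option once I move to a production deployment.
I don't know much about TensorFlow, but I eventually came up with a second workaround: after the first prediction, I call graph = tf.get_default_graph() and then wrap every prediction in with graph.as_default(): ... but I don't really understand what this does, or whether it's a good idea.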
Exception
Traceback (most recent call last):
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/tensorflow/python/client/session.py", line 1075, in _run
subfeed, allow_tensor=True, allow_operation=False)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3590, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3669, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("Placeholder:0", shape=(363, 27), dtype=float32) is not an element of this graph.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/core/handlers/exception.py", line 35, in inner
response = get_response(request)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/core/handlers/base.py", line 128, in _get_response
response = self.process_exception_by_middleware(e, request)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/core/handlers/base.py", line 126, in _get_response
response = wrapped_callback(request, *callback_args, **callback_kwargs)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/views/generic/base.py", line 69, in view
return self.dispatch(request, *args, **kwargs)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/utils/decorators.py", line 62, in _wrapper
return bound_func(*args, **kwargs)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/contrib/auth/decorators.py", line 21, in _wrapped_view
return view_func(request, *args, **kwargs)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/django/utils/decorators.py", line 58, in bound_func
return func.__get__(self, type(self))(*args2, **kwargs2)
File "/home/pembo13/virt/myproject/myproject/myproject/frontend/views.py", line 84, in dispatch
return self.handle(request, self.client)
File "/home/pembo13/virt/myproject/myproject/myproject/frontend/views.py", line 196, in handle
os.path.join(settings.BASE_DIR, 'ai', 'model', 'categoryModel.pkl') # modelPath
File "/home/pembo13/virt/myproject/myproject/ai/__init__.py", line 50, in predictCategory
model = modeler.loadModel(modelPath)
File "/home/pembo13/virt/myproject/myproject/ai/./modules/modeler.py", line 36, in loadModel
model.set_weights(modelPkl['modelWeights'])
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/keras/engine/network.py", line 515, in set_weights
K.batch_set_value(tuples)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2435, in batch_set_value
get_session().run(assign_ops, feed_dict=feed_dict)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/home/pembo13/virt/myproject/lib64/python3.6/site-packages/tensorflow/python/client/session.py", line 1078, in _run
'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(363, 27), dtype=float32) is not an element of this graph.
Answer 0 (score: 0)
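Right after loading or building the model, save a reference to the TensorFlow graph:
# clear_session() resets the Keras/TensorFlow state, so if you use it, call it before loading the model
tf.keras.backend.clear_session()
# ... load or build the model here ...
graph = tf.get_default_graph()
Then, in the code that handles each request (which may run in another thread), make that graph the default before predicting:
global graph  # belongs inside the handler function; only required if that function reassigns graph
with graph.as_default():
    (... do the prediction here ...)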
There seems to be a lengthy discussion of this issue here