背景:
我有一些复杂的强化学习算法,我想在多个线程中运行。
问题
尝试在线程中调用sess.run
时,出现以下错误消息:
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
代码重现该错误:
import tensorflow as tf
import threading
def thread_function(sess, i):
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})
def main(sess):
thread_list = []
for i in range(0, 4):
t = threading.Thread(target=thread_function, args=(sess, i))
thread_list.append(t)
t.start()
for t in thread_list:
t.join()
if __name__ == '__main__':
sess = tf.Session()
main(sess)
如果我在线程外运行相同的代码,它将正常工作。
有人可以提供一些关于如何在python线程中正确使用Tensorflow会话的见解吗?
答案 0 :(得分:2)
不仅Session可以是当前线程的默认值,还可以是图形。
当您传递会话并在其上调用run
时,默认图形将是另一个图形。
您可以这样修改 thread_function 使其起作用:
def thread_function(sess, i):
with sess.graph.as_default():
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})
但是,我不希望有任何明显的提速。 Python线程在某些其他语言中并不意味着什么,只有某些操作(例如io)可以并行运行。对于CPU繁重的操作,它不是很有用。多处理可以真正地并行运行代码,但您不会共享同一会话。
答案 1 :(得分:2)
在github上用另一个资源扩展了de1的答案: tensorflow/tensorflow#28287 (comment)
以下为我解决了tf的多线程兼容性:
# on thread 1
session = tf.Session(graph=tf.Graph())
with session.graph.as_default():
k.backend.set_session(session)
model = k.models.load_model(filepath)
# on thread 2
with session.graph.as_default():
k.backend.set_session(session)
model.predict(x)
这将保留其他线程的Session
和Graph
。
该模型在其“上下文”(而不是默认上下文)中加载,并保留供其他线程使用。
(默认情况下,模型会加载为默认的Session
和默认的Graph
)
另一个优点是它们被保存在同一对象中-更易于处理。