我正在从检查点文件恢复Tensorflow图,然后尝试使用加载的模型运行几个不同的前馈计算。
为了保持代码模块化,我想将前馈计算放在不同的函数中。传递Tensorflow会话的正确方法是什么?现在,我的结构如下:
def setup_graph(ckpt_file):
sess = tf.Session():
saver = tf.train.import_meta_graph(ckpt_file)
saver.restore(sess, ckpt_file)
graph = tf.get_default_graph()
# [ Here, get some placeholder tensors in the graph and save into placeholders dict ]
return sess, graph, placeholders
# Function 1 using the restored model
def predict(sess, graph, placeholders, input_sound):
# Use the graph being passed around to get other tensors we need
pred = graph.get_tensor_by_name("pred:0")
# Use the session being passed around to evaluate the tensors we need
prediction = sess.run([pred], feed_dict={placeholders["input"]: input_sound})
# Function 2 using the restored model
def get_embedding(sess, graph, placeholders, input_sound):
embed = graph.get_tensor_by_name("embed:0")
embedding = sess.run([embed], feed_dict={placeholders["input"]: input_sound})
在函数之间传递一个开放的tensorflow会话感觉有些不对,但我不确定如何保持模块化。有没有标准的方法来做到这一点?