具有张量流会话的函数范围

时间:2017-06-20 18:29:56

标签: python tensorflow

我正在从检查点文件恢复Tensorflow图,然后尝试使用加载的模型运行几个不同的前馈计算。

为了保持代码模块化,我想将前馈计算放在不同的函数中。传递Tensorflow会话的正确方法是什么?现在,我的结构如下:

def setup_graph(ckpt_file):
    sess = tf.Session():
    saver = tf.train.import_meta_graph(ckpt_file)
    saver.restore(sess, ckpt_file)
    graph = tf.get_default_graph()
    # [ Here, get some placeholder tensors in the graph and save into placeholders dict ]
    return sess, graph, placeholders

# Function 1 using the restored model
def predict(sess, graph, placeholders, input_sound):
    # Use the graph being passed around to get other tensors we need
    pred = graph.get_tensor_by_name("pred:0")
    # Use the session being passed around to evaluate the tensors we need
    prediction = sess.run([pred], feed_dict={placeholders["input"]: input_sound})

 # Function 2 using the restored model
 def get_embedding(sess, graph, placeholders, input_sound):
     embed = graph.get_tensor_by_name("embed:0")
     embedding = sess.run([embed], feed_dict={placeholders["input"]: input_sound})

在函数之间传递一个开放的tensorflow会话感觉有些不对,但我不确定如何保持模块化。有没有标准的方法来做到这一点?

0 个答案:

没有答案