Tensorflow:在单独的功能中训练和测试

时间:2016-08-04 13:48:44

标签: tensorflow

我试图在两个单独的函数中使用Tensorflow模型:一个用于训练它,另一个用于测试它。例如,训练函数看起来像这样:

graph = tf.Graph()
with graph.as_default():
    tf_dataset = tf.placeholder(tf.float32, shape=(None, num_dims))
    ...
    weights = tf.Variable(tf.truncated_normal([num_dims, num_labels]))
    ...
    optimizer =    tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    prediction = tf.nn.softmax(logits)
    ...
    session = tf.Session(graph=graph)
    ...

另一个评估函数只使用prediction和测试数据,如下所示:

session.run(prediction, feed_dict={tf_dataset: test_data})

当然,问题是tf_dataset不在另一个函数的范围内。我可以从培训函数返回sessionprediction,但是必须与评估代码共享每个占位符似乎有点蹩脚。

有没有办法从会话或图表中以某种方式获取引用?另外,有没有关于如何在Tensorflow中分离培训和评估代码的良好实践?

1 个答案:

答案 0 :(得分:2)

您可以为占位符指定唯一名称并使用它。 IE,

tf_dataset = tf.placeholder(tf.float32, shape=(None, num_dims), name="datainput")
...
sess.run(..., feed_dict={"datainput:0": mydata})

您还可以获取图表中所有操作的名称/类型对,这样就可以恢复所有占位符张量名称

[(op.name+":0", op.op_def.name) for op in graph.get_operations()]