我试图在两个单独的函数中使用Tensorflow模型:一个用于训练它,另一个用于测试它。例如,训练函数看起来像这样:
graph = tf.Graph()
with graph.as_default():
tf_dataset = tf.placeholder(tf.float32, shape=(None, num_dims))
...
weights = tf.Variable(tf.truncated_normal([num_dims, num_labels]))
...
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
prediction = tf.nn.softmax(logits)
...
session = tf.Session(graph=graph)
...
另一个评估函数只使用prediction
和测试数据,如下所示:
session.run(prediction, feed_dict={tf_dataset: test_data})
当然,问题是tf_dataset
不在另一个函数的范围内。我可以从培训函数返回session
和prediction
,但是必须与评估代码共享每个占位符似乎有点蹩脚。
有没有办法从会话或图表中以某种方式获取引用?另外,有没有关于如何在Tensorflow中分离培训和评估代码的良好实践?
答案 0 :(得分:2)
您可以为占位符指定唯一名称并使用它。 IE,
tf_dataset = tf.placeholder(tf.float32, shape=(None, num_dims), name="datainput")
...
sess.run(..., feed_dict={"datainput:0": mydata})
您还可以获取图表中所有操作的名称/类型对,这样就可以恢复所有占位符张量名称
[(op.name+":0", op.op_def.name) for op in graph.get_operations()]