我想在同一个TensorFlow会话中训练和测试我的模型。我使用两个不同的tf.FIFOQueue
来使用多个线程加载训练和测试数据(因为feed_dict
会导致性能不佳)。我尝试了两件事:
我尝试使用共享参数创建我的模型两次(用于训练和测试)。但我使用tf.contrib.layers.batch_norm
并且不允许共享批量标准化的参数。
我尝试使用tf.FIFOQueue
在is_training
布尔占位符上调整网络的输入tf.cond
,但显然tf.cond
执行tf.FIFOQueue
s出列功能,无论is_training
持有什么。
我想知道在不使用feed_dict
的情况下,在同一会话中训练和测试的常规设置是什么。
答案 0 :(得分:1)
显然tf.contrib.layers.batch_norm
允许共享批量规范化参数(如果在全局tf.variable_scope
中定义。
示例代码:取自here。
def model(data, is_training=False, reuse=None, scope='my_model'):
# Define a variable scope to contain all the variables of your model
with tf.variable_scope(scope, 'model', data, reuse=reuse):
....
net = tf.contrib.layers.batch_norm(net, is_training)
return net
train_outputs = model(train_data, is_training=True)
eval_outputs = model(eval_data, is_training=False, reuse=True)