我正在使用tf.layers.batch_normalization
API训练模型。在训练之后,我需要加载训练好的模型来对新数据进行预测。加载权重有两种方法,如下所示:
(1):
saver1 = tf.train.Saver(tf.global_variables(), max_to_keep=10)
saver1.restore(sess, '{}'.format(args.restore_ckpt))
(2):
saver2 = tf.train.import_meta_graph('{}.meta'.format(args.restore_ckpt))
saver2.restore(sess, '{}'.format(args.restore_ckpt))
我发现(1)可以产生很高的预测准确度(比如97%
),但是(2)的准确度要低得多(比如59%
)。这是否意味着(2)没有正确加载批量标准化层的权重?期待您的评论!
更新:
我发现,(1)加载的模型具有相同的预测精度,无论测试batch_size是什么。我尝试了批量大小为1和16,两种结果都有97%的准确度。
似乎加载了(2)的权重,我需要添加以下代码:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer_op.minimize(loss_total_op, global_step=global_step)
然后它会产生高精度97%
。也许我错误地理解批量标准化。