设置TensorFlow is_training训练标志(在batch_normalize中)

时间:2016-04-17 08:31:18

标签: python tensorflow

我想在TensorFlow中使用批量规范化,并在GitHub上遇到这个batch_normalize函数:link

我注意到有一个特定的标志来检查我们是否在训练。但是,我不熟悉如何将此标志设置为True或False,并且在训练时设置此特定标志是否标准?我指的是这里:

is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))

总结一下我的问题:如何将此标记设置为True / False以便我可以使用此batch_normalize函数?

谢谢!

2 个答案:

答案 0 :(得分:0)

该功能是scikit-flow a.k.a TF学习的一部分,而不是" base" TF - 您可以看到他们如何在图书馆的估算工具部分设置标记:GitHub link

self._graph.add_to_collection("IS_TRAINING", True)

部分self._graph是包含batchnorm操作的TF图。

答案 1 :(得分:0)

如果您将其用作传递到TensorFlowEstimator的自定义模型函数的一部分,则只需调用fit即可进行培训。当您致电predict时,batch_normalize将用于测试。

请注意,TensorFlow Learn(a.k.a Scikit Flow)会自动调用它们,因此您需要关注的是提供自定义模型函数并将其传递到TensorFlowEstimator