我想在TensorFlow中使用批量规范化,并在GitHub上遇到这个batch_normalize
函数:link
我注意到有一个特定的标志来检查我们是否在训练。但是,我不熟悉如何将此标志设置为True或False,并且在训练时设置此特定标志是否标准?我指的是这里:
is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))
总结一下我的问题:如何将此标记设置为True / False以便我可以使用此batch_normalize
函数?
谢谢!
答案 0 :(得分:0)
该功能是scikit-flow a.k.a TF学习的一部分,而不是" base" TF - 您可以看到他们如何在图书馆的估算工具部分设置标记:GitHub link
self._graph.add_to_collection("IS_TRAINING", True)
部分self._graph
是包含batchnorm操作的TF图。
答案 1 :(得分:0)
如果您将其用作传递到TensorFlowEstimator
的自定义模型函数的一部分,则只需调用fit
即可进行培训。当您致电predict
时,batch_normalize
将用于测试。
请注意,TensorFlow Learn(a.k.a Scikit Flow)会自动调用它们,因此您需要关注的是提供自定义模型函数并将其传递到TensorFlowEstimator
。