使用具有张量流的BatchNorm层训练Keras模型

时间:2017-04-01 11:58:04

标签: python tensorflow keras batch-normalization

我使用 keras 构建模型,并在 tensorflow 中编写优化代码和所有其他代码。当我使用非常简单的层,如密集 Conv2D 时,一切都很简单。但是将 BatchNormalization 图层添加到我的keras模型中会使问题变得复杂。

由于 BatchNormalization 图层在训练阶段和测试阶段的表现不同,我发现我的 feed_dict中需要 K.learning_phase():True 。但是下面的代码效果不好。它运行时没有错误,但模型的性能没有任何改善。

import keras.backend as K
...
x_train, y_train = get_data()
sess.run(train_op, feed_dict={x:x_train, y:y_train, K.learning_phase():True})

当我尝试使用keras fit 功能训练keras模型时,效果很好。

如何在 tensorflow 中使用 BatchNormalization 图层训练 keras 模型?

1 个答案:

答案 0 :(得分:1)

实际上我复制了这个我没见过的问题。

我找到了答案here,它只是将一个特殊参数传递给BatchNormalization图层调用