使用model.fit时如何将'training'参数传递给tf,keras.Model

时间:2019-07-18 07:01:19

标签: python tensorflow keras

所以我有这个由子类API编写的模型,调用签名看起来像call(x,training),在执行batchnorm和dropout时需要训练参数来区分训练和非训练。当我使用model.fit时,如何让模型知道自己处于训练模式或评估模式?

谢谢!

2 个答案:

答案 0 :(得分:1)

据我所知,对此没有论据。 Model.fit只是简单地根据提供的任何训练数据训练模型,并且在每个时期结束时,对提供的验证数据或通过使用validate_split评估训练。

答案 1 :(得分:0)

实际上,在文档https://www.tensorflow.org/beta/guide/keras/custom_layers_and_models中,它说:“某些层,尤其是BatchNormalization层和Dropout层,在训练和推理期间具有不同的行为。对于此类层,公开训练是标准做法调用方法中的(布尔)参数。

通过在调用中公开此参数,可以启用内置的训练和评估循环(例如,拟合)以在训练和推理中正确使用该层。”因此,我认为训练参数由keras自动传递。删除训练参数的默认值并且不会引发任何错误,因此很可能keras内置循环会执行此操作。