Tensorflow 2.0 keras fit会覆盖调用函数中的“ training”参数

时间:2019-05-15 18:39:35

标签: tensorflow tensorflow2.0

使用tensorflow 2.0创建模型时,根据我如何进行前向传递,我得到两种不同的行为:

1)如果我使用model(X)进行前向传递,那很好,并且调用方法中的“训练”参数正常工作

vs。

2)如果我改用model.fit(X,y)运行模型,那么“ training”参数似乎会被覆盖并设置为None,无论其默认值为True还是False。

有人知道为什么会这样吗?例如,这意味着我无法设置模型,因此仅当训练设置为True时才会出现辍学。

(if 'id' + '(' in array: print(count))

然后按预期打印 Trueing True

!pip install tensorflow-gpu==2.0.0-alpha0
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

X = np.random.random((250, 5))
y = X[:, 0] > 0 * 1.0

class MyModel(Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, inputs, training=True):
    print("Training ", training)
    x = self.dense1(inputs)
    if training:
      x = self.dropout(x, training=training)
    return self.dense2(x)

model = MyModel()

但这会打印出无培训

model(X)    # prints out: Training True

2 个答案:

答案 0 :(得分:0)

这是一个错误:

  

当tf.keras.Model.call()的训练值变为无时   tf.keras.Model.fit()。 (tf2.0.0-alpha0)

https://github.com/tensorflow/tensorflow/issues/27275

答案 1 :(得分:0)

新发布的TF2.0中已解决此问题。

请在以上代码中将!pip install tensorflow-gpu==2.0.0-alpha0替换为!pip install tensorflow-gpu==2.0.0。谢谢!