Tensorflow是否支持Keras模型fit()方法和急切执行?

时间:2018-04-18 00:31:10

标签: tensorflow keras

我正在训练Keras模型(tf.keras.models.Sequential)调用其方法fit()

由于我启用了急切执行,因此训练时间(相同数量的时期)从20.1s上升到49.4s。此外,培训似乎不再收敛,因为损失保持在9左右(没有急切执行,它下降到1),而方法fit()甚至没有报告所请求的指标"精度"了。

对Keras型号的热切执行支持?请注意,我在模型上调用方法fit(),而不是使用估算器。

这里是声明模型并进行培训的代码片段。使用TF 1.7安装pip3的GPU。

tf.enable_eager_execution()

model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(11,)) ,
    tf.keras.layers.Dense(64, activation='relu') ,
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(11, activation='softmax')
])

optimizer = tf.train.AdamOptimizer()
# optimizer = 'adam'
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(x=train_X, y=train_y, epochs=200, batch_size=64, verbose=2)

更新:在Tensorflow GITHUB上提交了#18642号问题。

3 个答案:

答案 0 :(得分:1)

以下是Tensorflow网站here

的引用
  

计算导数时,有许多参数需要优化。当结构化为可重用的类和对象而不是单个顶级函数时,TensorFlow代码更易于阅读。热切的执行鼓励在tf.keras.layers模块中使用Keras样式的图层类。此外,tf.train.Optimizer类提供了复杂的技术来计算参数更新。

这意味着使用Eager执行允许keras层和后续模型。 至于你的时间安排,这个链接还提到了如何使用渴望停止构建图表。

  

TensorFlow的热切执行是一个必要的编程环境,它可以立即评估操作,而无需额外的图形构建步骤。操作返回具体值,而不是构建计算图以便稍后运行。

考虑到您拥有的DENSE图层数量,这可能会使您的模型更难运行。有人可能会纠正我,因为我之前没有做过很多关于DENSE层的工作,或者说我已经很久了。如果这不起作用,那么我会调查你的损失函数。 This answer may help if that becomes a problem

但其他一切看起来都不错。希望这会有所帮助。

修改

好的,我看到你说的是命运。是的,第一个链接使用顺序模型,但梯度磁带渐变得体。深入阅读热切的教程表明他们也只使用Gradient磁带。以下是教程中关于培训的内容:

  

自动微分对于实现机器学习算法(例如用于训练神经网络的反向传播)非常有用。在急切执行期间,使用tfe.GradientTape跟踪稍后计算梯度的操作.tfe.GradientTape是一种选择加入功能,可在不跟踪时提供最大性能。由于在每次呼叫期间可能发生不同的操作,因此所有前向传递操作都被记录到" tape"。要计算渐变,请向后播放磁带然后丢弃。特定的tfe.GradientTape只能计算一次,后续调用会产生运行时错误。

所以也许就像现在一样,只有Gradient磁带和估算器方法是你应该用的渴望。

答案 1 :(得分:1)

我在tensorflow上报道的问题得到了答案:

  

感谢您提供错误报告。我们已经解决了这个问题   很快就会出现在GitHub上。

请参阅GITHUB for Tensorflow上的问题#18642。

基于此,我了解一旦错误得到修复,Keras模型的方法fit()将得到热切执行的支持。

答案 2 :(得分:0)

在阅读 Model (documentation) 上的 compile 方法时,可以找到一个参数,run_eagerly

<块引用>

run_eagerly:布尔值。默认为假。如果为 True,则此模型的逻辑将不会包含在 tf.function 中。除非您的模型无法在 tf.function 中运行,否则建议将此设置为 None。

因此,默认情况下,tf.keras.Model 将默认通过图执行运行,而不是急切执行。