Keras和AutoGraph

时间:2019-06-19 13:22:35

标签: python tensorflow keras tensorflow2.0

阅读thisthis answer我了解到,对于TensorFlow-2.0上的非动态模型,Keras将使用AutoGraph。但是现在编写回调以获取训练期间变量的历史记录,

class TrainHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.vars = []

    def on_batch_end(self, batch, logs={}):
        self.vars.append([v.numpy() for v in self.model.variables])

我看到可以使用.numpy()的急张量方法。我期待有一个类似numpy() is only available when eager execution is enabled的错误。将Keras与TensorFlow-2.0结合使用时,是否有一个正在执行的热切代码?

Here提到对于像卷积这样的计算密集型函数,与渴望的代码相比,装饰tf.function的函数的速度没有太大提高,但是从示例中显示的数字来看,我想这是有区别的对于长期培训可能具有重要意义。在用GradientTape装饰的自定义训练函数上使用tf.function代替Keras的fit()方法会更好吗?

1 个答案:

答案 0 :(得分:0)

如果您打算在外部循环(即“时代”循环)上使用@tf.function,则可能不会给您的模型带来太多好处。这只会使开发更加困难。代码越多,复杂性就越大。

但是,您必须在自定义损失函数和其他每批调用一次的函数上绝对使用tf.function

而且,不,使用GradientTape -d自定义tf.function可能不会超过keras的fit方法。这些年来,它经过了彻底的测试和完善。

这能回答您的问题吗?