使用tf.function的Tensorflow 2.0模型非常慢,并且每次火车数量变化时都会重新编译。渴望的速度快大约4倍

时间:2019-04-16 14:55:32

标签: keras tensorflow2.0

我有未编译的keras代码构建的模型,并试图通过自定义训练循环运行它们。

TF 2.0急切(默认)代码在CPU(笔记本电脑)上运行大约30秒。当我用包装的tf.function调用方法创建一个keras模型时,它运行的速度非常慢,而且启动时间似乎很长,尤其是“第一次”。

例如,在tf.function代码中,对10个样本的初始训练需要40秒,而对10个样本的后续训练则需要2秒。

在20个样本上,初始样本需要50秒,后续样本需要4秒。

对一个样本进行的第一轮训练需要2s,而后续动作则需要200 ms。

因此,每次火车调用似乎都在创建一个新图,其中复杂度随火车数量成比例!?

我只是在做这样的事情:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d

其中的模型是keras.model.Model,其示例使用@ tf.function装饰call方法。

1 个答案:

答案 0 :(得分:6)

我在Using a Python native type处分析了@tf.function的这种行为。

简而言之:tf.function的设计不会自动将Python本机类​​型装箱到具有明确定义的tf.Tensor的{​​{1}}对象。

如果您的函数接受dtype对象,则在第一次调用时将分析该函数,然后将创建图形并将其与该函数关联。在每个非首次调用中,如果tf.Tensor对象的dtype匹配,则图被重用。

但是如果使用Python本机类​​型,则每次使用不同的值调用该函数时都会构建graphg

简而言之:如果计划使用tf.Tensor,则设计代码以在各处使用tf.Tensor而不是Python变量。

@tf.function并不是神奇地加速在急切模式下运行良好的功能的包装器;是一个包装程序,需要设计eager函数(正文,输入参数,dytpes)以了解创建图形后将发生的情况,从而获得真正的加速效果。