在tf.keras中

时间:2018-10-08 21:54:23

标签: python tensorflow keras

我正在尝试在tf.keras中实现DARTS: Differentiable Architecture Search。该方法相对简单,但是至关重要的是,不同的图层集会交替分批更新。 非常天真的方法是通过回调使用不同的可训练层重新编译模型:

class SwitchTrainableLayers(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs={}):
        for layer in self.model.layers:
            layer.trainable = not layer.trainable
        self.model.compile(...)

正如人们可能期望的那样,这是非常低效的。我知道也可以通过使用两个不同的更新操作显式写出整个训练循环来完成此操作,但是我的其余模型都使用fit_generator,所以我希望找到一种不需要太多的解决方法TF样板。

另一种选择是创建具有不同可训练层集的两个模型,交替地将一个批次添加到一个模型中,然后将更新的重量复制到另一个模型中。

也许可以通过将Optimizer子类化以为每个批次的不同层提供更新来做到这一点吗?还有什么其他优雅的解决方案可以让我在培训/实验代码中保持在keras抽象水平上?

0 个答案:

没有答案