在不使用培训代码的情况下恢复TF Eager模型

时间:2018-05-13 17:40:15

标签: python tensorflow

我正在以急切模式训练(并保存)一个非常简单的模型,如下所示:

import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

NUM_EXAMPLES = 2000

training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
outputs = training_inputs * 3 + 2 + noise


class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.W = tfe.Variable(5., name="weight")
        self.b = tfe.Variable(0., name="bias")

    def predict(self, input):
        return self.W * input + self.b


def loss(model, inputs, outputs):
    error = model.predict(inputs) - outputs
    return tf.reduce_mean(tf.square(error))


def grad(model, inputs, outputs):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, outputs)
    return tape.gradient(loss_value, [model.W, model.b])


if __name__ == "__main__":
    model = Model()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

    for i in range(300):
        gradients = grad(model, training_inputs, outputs)
        optimizer.apply_gradients(zip(gradients, [model.W, model.b]),
                                  global_step=tf.train.get_or_create_global_step())

    checkpoint_dir = './checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

    root = tfe.Checkpoint(optimizer=optimizer,
                          model=model,
                          optimizer_step=tf.train.get_or_create_global_step())
    root.save(file_prefix=checkpoint_prefix)

我发现保存/恢复的唯一方法(使用CheckpointSaver)意味着可以访问Model类以将其加载到其他位置,例如:

model = Model()
checkpointer = tfe.Checkpoint(model=model)
checkpointer.restore(tf.train.latest_checkpoint('checkpoints/'))
print(model.predict(7))

来自save的{​​{1}}方法似乎尚未针对Eager模式实施:

tf.keras.Model

是否有其他方法可以保存和加载模型而无需实例化新的model.save("keras_model") >>> NotImplementedError 对象?

0 个答案:

没有答案