我正在以 Eager 模式训练(并保存)一个非常简单的模型,代码如下:
import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe
# TF 1.x: eager execution must be switched on before any tensor is created.
tf.enable_eager_execution()

# Synthetic 1-D regression data: targets follow 3 * x + 2 plus random noise,
# so a well-trained model should end up near W = 3, b = 2.
NUM_EXAMPLES = 2000
training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
outputs = 3 * training_inputs + 2 + noise
class Model(tf.keras.Model):
    """Single-feature linear model computing ``W * x + b``.

    ``W`` and ``b`` are assigned as attributes of the Keras model, which is
    how a ``Checkpoint(model=...)`` referencing this object finds them.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Start away from the data-generating values (3 and 2) so training
        # has something to learn.
        self.W = tfe.Variable(5., name="weight")
        self.b = tfe.Variable(0., name="bias")

    def call(self, inputs):
        """Keras forward pass; lets ``model(x)`` work for this subclass.

        Subclassed Keras models are expected to implement ``call``; without it
        invoking the model object raises ``NotImplementedError``.
        """
        return self.W * inputs + self.b

    def predict(self, input):
        """Return ``W * input + b``.

        NOTE: this overrides ``tf.keras.Model.predict`` (batched inference)
        with a one-argument version. The parameter name ``input`` shadows the
        builtin but is kept so existing keyword callers keep working.
        """
        return self.call(input)
def loss(model, inputs, outputs):
    """Mean squared error between ``model.predict(inputs)`` and ``outputs``."""
    residuals = model.predict(inputs) - outputs
    squared_residuals = tf.square(residuals)
    return tf.reduce_mean(squared_residuals)
def grad(model, inputs, outputs):
    """Return the gradients of ``loss`` w.r.t. ``[model.W, model.b]``, in that order."""
    trainable = [model.W, model.b]
    # Record the forward pass so the tape can differentiate it afterwards.
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, outputs)
    return tape.gradient(loss_value, trainable)
if __name__ == "__main__":
    model = Model()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    trainable = [model.W, model.b]

    # 300 full-batch gradient-descent steps; passing global_step makes the
    # optimizer increment the step counter on every update.
    for _ in range(300):
        gradients = grad(model, training_inputs, outputs)
        optimizer.apply_gradients(
            zip(gradients, trainable),
            global_step=tf.train.get_or_create_global_step(),
        )

    # Save the optimizer, the model's variables and the step counter together.
    checkpoint_dir = './checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    root = tfe.Checkpoint(
        optimizer=optimizer,
        model=model,
        optimizer_step=tf.train.get_or_create_global_step(),
    )
    root.save(file_prefix=checkpoint_prefix)
我发现唯一可行的保存/恢复方法(使用 `Checkpoint` 或 `Saver`)都要求在其他地方加载模型时能够访问 `Model` 类,例如:
# Recreate the model object first, then load the newest checkpoint into it.
model = Model()
checkpointer = tfe.Checkpoint(model=model)
latest = tf.train.latest_checkpoint('checkpoints/')
checkpointer.restore(latest)
print(model.predict(7))
`tf.keras.Model` 的 `save` 方法似乎尚未支持 Eager 模式:
model.save("keras_model")
>>> NotImplementedError
是否有其他方法可以保存和加载模型,而无需实例化一个新的 `Model` 对象?