我正在尝试使用TensorFlow中的模型进行保存,还原和预测。已经有很多答案了,但是没有一个是专门针对生产中的问题的,所以我相信这个问题将对寻求有关该主题的实践方法的人们有所帮助。
因此,我训练了this jupyter code上可用的模型,现在我试图保存模型并在不同的图像上进行预测。我能够运行它和got this result after training。
然后,我保存了模型,将其加载并在同一张图像上再次进行了预测。 This is the result after testing on the restored model。显然我做错了,要么保存模型,要么加载模型。
我用于保存模型的代码如下:
在def run():
...
with tf.Session() as session:
# Returns the three layers, keep probability and input layer from the vgg architecture
image_input, keep_prob, layer3, layer4, layer7 = load_vgg(session, VGG_PATH)
# The resulting network architecture, adding a decoder on top of the given vgg model
model_output = layers(layer3, layer4, layer7, NUMBER_OF_CLASSES)
logits, train_op, cross_entropy_loss = optimize(model_output, correct_label, learning_rate, NUMBER_OF_CLASSES)
# Create saver
saver = tf.train.Saver()
# Initilize all variables
session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
# train the neural network
train_nn(session, EPOCHS, BATCH_SIZE, get_batches_fn,
train_op, cross_entropy_loss, image_input,
correct_label, keep_prob, learning_rate)
# Save inference data
helper.save_inference_samples(RUNS_DIRECTORY, DATA_DIRECTORY, session, IMAGE_SHAPE, logits, keep_prob, image_input)
# Save model
saver.save(session, "./fcn_liquid_model", global_step = 500)
然后,为了加载保存的模型,我尝试执行以下操作:
tf.reset_default_graph:
with tf.Session() as session:
# Restore variables and model
saver = tf.train.import_meta_graph("./fcn_liquid_model-500.meta")
saver.restore(session, tf.train.latest_checkpoint("./"))
print("Model restored.")
# Returns the three layers, keep probability and input layer from the vgg architecture
image_input, keep_prob, layer3, layer4, layer7 = load_vgg(session, VGG_PATH)
# The resulting network architecture, adding a decoder on top of the given vgg model
model_output = layers(layer3, layer4, layer7, NUMBER_OF_CLASSES)
logits = tf.reshape(model_output, (-1, NUMBER_OF_CLASSES))
# Initilize all variables
session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
helper.save_inference_samples(RUNS_DIRECTORY, DATA_DIRECTORY, session, IMAGE_SHAPE, logits, keep_prob, image_input)
得到前面显示的结果。有谁知道发生了什么或如何解决该问题?谢谢。