我对机器学习和Tensorflow框架比较陌生。我试图使用MNIST手写数字数据集,对受here显示的代码影响很大的受过训练的模型,并对我创建的测试示例进行推断。但是,我正在使用GPU的远程计算机上进行培训,并尝试将数据保存到目录中,以便可以在本地计算机上传输数据和推断
似乎我可以用tf.saved_model.simple_save
保存一些模型,但是,我不确定如何使用保存的数据进行推断以及如何在给定新数据的情况下使用数据进行预测图片。似乎有多种方法可以保存模型,但是我不确定使用Tensorflow框架进行哪种惯例或“正确方式”。
到目前为止,这是我认为我需要的行,但是不确定是否正确。
tf.saved_model.simple_save(sess, 'mnist_model',
inputs={'x': self.x},
outputs={'y_': self.y_, 'y_conv':self.y_conv})
如果有人可以指出如何正确保存经过训练的模型以及可以使用哪些变量来推断保存的模型的方向,我将非常感激。
答案 0 :(得分:1)
执行此操作的一种方法是在图形定义中创建tf.train.Saver()
对象,然后使用该对象将网络保存到指定目录。然后可以将该目录中的权重从远程计算机下载到本地,然后在本地还原。这是一个小的示例网络:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# >>>> Config. Vars <<<<
TRAIN_STEPS = 1000
SAVE_EVERY = 100
# >>>> Network <<<<
inputs = tf.placeholder(tf.float32, shape=[None, 784])
labels = tf.placeholder(tf.float32, shape=[None, 10])
h1 = tf.layers.dense(inputs, 256, activation=tf.nn.relu, use_bias=True)
logits = tf.layers.dense(h1, 10, use_bias=True)
predictions = tf.nn.softmax(logits)
prediction_ids = tf.argmax(predictions, axis=1)
# >>>> Loss & Optimisation <<<<
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
opt = tf.train.AdamOptimizer().minimize(loss)
# >>>> Utilities <<<<
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
# >>>> Training - run on remote, comment out locally <<<<
for i in range(TRAIN_STEPS):
print("Train step {}".format(i), end="\r")
batch_data, batch_labels = mnist.train.next_batch(batch_size=128)
feed_dict = {
inputs: batch_data,
labels: batch_labels
}
l, _ = sess.run([loss, opt], feed_dict=feed_dict)
if i % SAVE_EVERY == 0:
saver.save(sess, "saved_model/network_weights.ckpt")
# >>>> Using the network - run locally to use the network <<<
saver.restore(sess, "saved_model/network_weights.ckpt")
test_data, test_labels = mnist.test.images, mnist.test.labels
feed_dict = {
inputs: test_data,
labels: test_labels
}
preds = sess.run(prediction_ids, feed_dict=feed_dict)
print(preds)
因此,一旦您在网络中定义了保护程序,就可以使用它将权重保存到指定的目录-在这种情况下,将保存在目录“ saved_models”中,在运行此特定代码之前需要先创建该目录
恢复模型就像调用saver.restore()
一样简单,然后将会话和权重存储的路径传递给模型。因此,您可以在远程计算机上运行此代码,将“ saved_models”目录下载到本地计算机,然后在注释掉训练部分以实际使用模型的情况下运行此代码。