我正在尝试保存Tensorflow模型并重复使用它。
为了清楚地理解这个问题,我创建了一个包含10个元素的二进制数据集,并重复对这10个元素进行训练,而每100次迭代保存模型。 然后在同一组上运行测试。理想情况下,我希望测试在保存模型时产生相同的成本。 但是,我可能会错过一些东西并加载一个训练有素的模型并没有给出预期的成本价值:
def model(X, w1, w2, w3, w4, wo, p_keep_conv, p_keep_hidden):
l1 = tf.nn.relu(tf.nn.conv2d(X, w1, strides=[1, 1, 1, 1], padding='SAME'))
l1 = tf.nn.max_pool(l1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
l1 = tf.nn.dropout(l1, p_keep_conv)
# ... other layer def.s
l4 = tf.nn.relu(tf.matmul(l3, w4))
l4 = tf.nn.dropout(l4, p_keep_hidden)
return tf.matmul(l4, wo, name="pyx")
X = tf.placeholder("float", [None, size1, size2, size3], name="X")
Y = tf.placeholder("float", [None, 1], name="Y")
py_x = model(X, wo, p_keep_conv, p_keep_hidden)
cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=py_x, targets=Y, pos_weight=POS_WEIGHT))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
with tf.Session() as sess:
batch_x, batch_y = read_file('train.dat', 10)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
for step in range(NUM_TRAIN_BATCHES):
x, y = sess.run([batch_x, batch_y])
_, costval = sess.run([train_op, cost], feed_dict={X: x, Y: y, p_keep_conv: 0.8, p_keep_hidden: 0.5})
if step % 100 == 0
print("Step %d, cost %1.5f" % (step, cost_value))
saver.save(sess, './train.model', global_step=step)
以上代码打印如下:
Step 0, cost 1.10902
Step 100, cost 0.83170
Step 200, cost 0.00003
Step 300, cost 0.00000
现在,如果我加载在第300次迭代期间保存的模型并尝试应用于相同的数据:
model_no = 300
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./train.model-%d.meta' % (model_no))
saver.restore(sess, tf.train.latest_checkpoint('./'))
batch_x, batch_y = read_file('train.dat', 10)
sess.run(tf.global_variables_initializer())
x, y = sess.run([batch_x, batch_y])
cost_value = sess.run(cost, feed_dict={"X:0": x, "Y:0": y, p_keep_conv: 0.8, p_keep_hidden: 0.5})
print("cost %1.5f" % (cost_value))
以上打印:
cost loss 1.10895
对模型训练的第一次迭代非常接近。
另一件我无法理解的事情是检查点文件,其中仅包含以下内容:
model_checkpoint_path: "train.model-300"
all_model_checkpoint_paths: "train.model-0"
all_model_checkpoint_paths: "train.model-100"
all_model_checkpoint_paths: "train.model-200"
all_model_checkpoint_paths: "train.model-300"
如果检查点只包含模型文件的路径并明确加载特定模型,那么它如何帮助以及调用saver.restore(sess, tf.train.latest_checkpoint('./'))
背后的想法是什么?
答案 0 :(得分:0)
你的问题在这里:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./train.model-%d.meta' % (model_no))
saver.restore(sess, tf.train.latest_checkpoint('./'))
batch_x, batch_y = read_file('train.dat', 10)
sess.run(tf.global_variables_initializer()) # <------
您正在重新初始化所有变量,这意味着您将再次使用随机权重覆盖已加载的权重。如果您加载某些内容或从头开始,或先初始化然后加载,请先检查。或者,使用像TF的主管那样为你做记账。