恢复Tensorflow模型而不从目录中提取

时间:2018-10-09 19:19:00

标签: python tensorflow model neural-network

我目前正在使用Tensorflow的Saver类保存和恢复神经网络模型,如下所示:

saver.save(sess, checkpoint_prefix, global_step=step)

saver.restore(sess, checkpoint_file)

这会将.ckpt个模型文件保存到指定路径。由于我正在运行多个实验,因此保存这些模型的空间有限。

我想知道是否有一种方法可以在不将内容保存到指定目录的情况下保存这些模型。

例如我可以只在最后一个检查点将某个对象传递给某些validate()函数,然后从该对象恢复模型吗?

据我所知,save_path中的tf.train.Saver.restore()参数不是可选的。

任何见识将不胜感激。

谢谢

1 个答案:

答案 0 :(得分:1)

您可以使用加载的图形和权重以与训练相同的方式进行评估。您只需要将输入更改为评估集中的输入即可。这是一个训练循环的伪代码,每个1000迭代都有一个评估循环(假设您已经创建了一个名为tf.Session的{​​{1}})

sess

如果您使用x = tf.placeholder(...) loss, train_step = model(x) for i in range(num_step): input_x = get_train_data(i) sess.run(train_step, feed_dict={x: input_x}) if i % 1000 == 0 and i != 0: eval_loss = 0 for j in range(num_eval): input_x = get_eval_data(j) eval_loss += sess.run(loss, feed_dict={x: input_x}) print(eval_loss/num_eval) 作为输入,则只需创建一个tf.data即可选择要使用的输入:

tf.cond

is_training = tf.placeholder(tf.bool) next_element = tf.cond(is_training, lambda: get_next_train(), lambda: get_next_eval()) get_next_train必须创建用于读取数据集的所有操作,否则运行上述代码会有副作用。

这样,如果您不想的话,就不必保存任何内容。