恢复张量流模型

时间:2018-05-03 14:25:38

标签: python tensorflow

我希望在经过训练后恢复张量流模型。我知道我可以使用tf.train.Saver,但问题在于恢复,因为我对get_tensor_by_name的名称感到困惑。有谁能够帮我? 这是我的图表:

x_hat = tf.placeholder(tf.float32, shape=[None, dim_img], name='input_img')
x = tf.placeholder(tf.float32, shape=[None, dim_img], name='target_img')

# dropout
keep_prob = tf.placeholder(tf.float32, name='keep_prob')

# input for PMLR
z_in = tf.placeholder(tf.float32, shape=[None, dim_z], name='latent_variable')

# network architecture
y, z, loss, neg_marginal_likelihood, KL_divergence = vae.autoencoder(x_hat, x, dim_img, dim_z, n_hidden,
                                                                                keep_prob)

1 个答案:

答案 0 :(得分:0)

当您保存模型时,您可以保存两件事:1)元图,即图表的表示(您定义的所有TF符号;以及2)包含实际变量值的检查点(按名称保存和恢复。)

还原时,您可以还原其中一个或两个组件。您所描述的是恢复元图和检查点数据。在这种情况下,你需要通过名称查找你感兴趣的各种操作和张量,这可能会令人困惑(特别是如果你没有很好地命名你的变量,你应该总是这样做。)

# In this method you import the meta graph then restore
saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
saver.restore(sess, 'my-save-dir/my-model-10000')

恢复的另一个选项(我更喜欢mysefl)是根本不加载元图。相反,只需重新运行您最初用于创建图形的相同代码(如果您做得好,这将全部组织在一个地方)。然后你只恢复检查点。这种方法的好处是您可以轻松地引用您需要的所有操作(例如成本,train_op,占位符等)。

# This method only performs the restor operation 
# assuming the graph is already constructure
saver.restore(sess, 'my-save-dir/my-model-10000')