如何使用tf.Estimator忽略不在检查点中的变量?

时间:2017-10-10 14:41:48

标签: tensorflow

我使用tf.Estimator来训练和评估我的模型。在评估期间,我想使用张量板投影仪可视化。为此,我需要使用我想要可视化的功能创建和填充变量。我的model_fn如下所示:

def model_fn(...):
  ....
  predictions = net(features, is_training=is_training)
  ...

  if mode == ModeKeys.EVAL:
    embedding_var = tf.get_variable("feature_embedding", ...)
    update_embedding = embedding_var.assign(predictions)
  ....     

问题是embedding_var仅存在于评估图中。这会导致以下错误

NotFoundError (see above for traceback): Key feature_embedding not found in checkpoint

有什么想法吗?

1 个答案:

答案 0 :(得分:2)

你能把它变成局部变量吗?这是指标的作用。所以它是tf.get_variable("feature_embedding", collections=[tf.GraphKeys.LOCAL_VARIABLES], ...)