我使用tf.Estimator
来训练和评估我的模型。在评估期间,我想使用张量板投影仪可视化。为此,我需要使用我想要可视化的功能创建和填充变量。我的model_fn
如下所示:
def model_fn(...):
....
predictions = net(features, is_training=is_training)
...
if mode == ModeKeys.EVAL:
embedding_var = tf.get_variable("feature_embedding", ...)
update_embedding = embedding_var.assign(predictions)
....
问题是embedding_var
仅存在于评估图中。这会导致以下错误
NotFoundError (see above for traceback): Key feature_embedding not found in checkpoint
有什么想法吗?
答案 0 :(得分:2)
你能把它变成局部变量吗?这是指标的作用。所以它是tf.get_variable("feature_embedding", collections=[tf.GraphKeys.LOCAL_VARIABLES], ...)