如何在保存的Tensorflow模型中使用局部变量

时间:2017-05-13 04:05:18

标签: python tensorflow

如何在保存的Tensorflow模型中使用局部变量?特别是,我想使用tf.contrib.metrics.streaming_auc。通常情况下,在使用此操作之前,它会sess.run(tf.local_variables_initializer())(它似乎有真正的正面局部变量等)。当我刚制作模型时,这种方法很好。但是,当我重新加载模型时,运行sess.run(tf.local_variables_initializer())会给我AttributeError: 'Tensor' object has no attribute 'initializer'。如果我尝试使用auc op而不初始化局部变量,我会得到FailedPreconditionError: Attempting to use uninitialized value auc/true_positives。我是否需要获取/保存对auc op中使用的局部变量的引用,并在重新加载模型时以某种方式直接初始化它们?

这是一个失败的例子:

inputs = tf.random_normal([10, 2])
labels = tf.reshape(tf.constant([1] * 5 + [0] * 5), (-1, 1))
logits = tf.layers.dense(inputs, 1)
pred = tf.sigmoid(logits)
auc_val, auc_op = tf.contrib.metrics.streaming_auc(pred, labels)

sess = tf.Session()
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

print(sess.run(auc_op)) # some terrible AUC, but it works
saver = tf.train.Saver()
saver.save(sess, 'test.model')

然后,在一个新的过程中:

sess = tf.Session()

saver = tf.train.import_meta_graph('test.model.meta')
saver.restore(sess, 'test.model')
auc_op = tf.get_default_graph().get_operation_by_name('auc/update_op')
print(sess.run(auc_op))
# FailedPreconditionError: Attempting to use uninitialized value auc/true_positives

另请注意,如果您sess.run(tf.local_variables_initializer())尝试初始化,则会出现上述其他错误。

一个人可以在加载后制作一个新的auc操作,但没有必要更好。此外,对于新的auc操作,您必须初始化局部变量,因此您必须从前一个auc操作中local_variables剩余的非实体化Tensors中过滤掉那些变量。

0 个答案:

没有答案