在tensorflow中,我发现API tf.add_to_collcetion为集合添加一些值,如下面的代码。
def accuracy_rate(logits, labels):
correct = tf.nn.in_top_k(logits, labels, 1)
# Return the accuracy of true entries.
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
return accuracy
with tf.Session() as sess:
logits, labels = ...
accuracy = accuracy_rate(logits, labels)
tf.add_to_collection('total_accuracy', sess.run(accuracy))
我在API中找不到的是,如何清除我已经存储在一个集合中的所有值?
答案 0 :(得分:5)
您可以使用tf.get_collection_ref
获取可以清除的集合的可变引用(它只是一个python列表)。
答案 1 :(得分:0)
我认为这可能就是你要找的东西?
In [2]: import tensorflow as tf
In [3]: w = tf.Variable([[1,2,3], [4,5,6], [7,8,9], [3,1,5], [4,1,7]], collections=[tf.GraphKeys.WEIGHTS, tf.GraphKeys.GLOBAL_VARIABLES], dtype=tf.float32)
In [4]: params = tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
In [5]: del params[:]
In [6]: tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
Out[6]: []
In [10]: params = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
In [11]: params
Out[11]: [<tf.Variable 'Variable:0' shape=(5, 3) dtype=float32_ref>]
答案 2 :(得分:-1)
使用不同的tf.Graph()找到替代解决方案。