Tensorflow重置或清除收集

时间:2017-07-28 06:11:11

标签: python-3.x tensorflow

在tensorflow中,我发现API tf.add_to_collcetion为集合添加一些值,如下面的代码。

def accuracy_rate(logits, labels):
    correct = tf.nn.in_top_k(logits, labels, 1)
    # Return the accuracy of true entries.
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return accuracy 
with tf.Session() as sess:
    logits, labels = ...
    accuracy = accuracy_rate(logits, labels)
    tf.add_to_collection('total_accuracy', sess.run(accuracy))

我在API中找不到的是,如何清除我已经存储在一个集合中的所有值?

3 个答案:

答案 0 :(得分:5)

您可以使用tf.get_collection_ref获取可以清除的集合的可变引用(它只是一个python列表)。

答案 1 :(得分:0)

我认为这可能就是你要找的东西?

In [2]: import tensorflow as tf
In [3]: w = tf.Variable([[1,2,3], [4,5,6], [7,8,9], [3,1,5], [4,1,7]], collections=[tf.GraphKeys.WEIGHTS, tf.GraphKeys.GLOBAL_VARIABLES], dtype=tf.float32)
In [4]: params = tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
In [5]: del params[:]
In [6]: tf.get_collection_ref(tf.GraphKeys.WEIGHTS)                                                                                                                                                                  
Out[6]: []
In [10]: params = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
In [11]: params
Out[11]: [<tf.Variable 'Variable:0' shape=(5, 3) dtype=float32_ref>]

答案 2 :(得分:-1)

使用不同的tf.Graph()找到替代解决方案。