在tensorflow中,您可以定义自己的集合名称吗?

时间:2016-05-20 00:00:34

标签: collections tensorflow

我搜索了tensorflow的API文档中的所有资源,但无法找到任何指示。 似乎在使用get_variable()时,我可以为集合术语设置一个特定的名称,如:

x=tf.get_variable('x',[2,2],collections='my_scope')

但在执行时只获取空列表:

tf.get_collection('my_scope')

2 个答案:

答案 0 :(得分:6)

集合 S 需要list集合名称。

>>x = tf.get_variable('x',[2,2], collections=['my_scope'])
>>tf.get_collection('my_scope')

[<tensorflow.python.ops.variables.Variable at 0x10d8e1590>]

请注意,如果您使用它,其他一些操作可能会产生副作用。 像tf.all_variables()一样不起作用,因此tf.initialize_all_variables()也不会看到你的变量。修复它的一种方法是指定默认集合。

>>x = tf.get_variable('x',[2,2], collections=['my_scope', tf.GraphKeys.VARIABLES])

但事情开始变得乏味。

答案 1 :(得分:1)

实际上,您可以使用tf.get_collection创建新集合:

tf.get_collection('my_collection')
var = tf.get_variable('var', [2, 2], initializer=tf.constant_initializer())
tf.add_to_collection('my_collection', var)