保存单个TensorFlow图层而不指定基础变量

时间:2017-12-12 01:23:25

标签: tensorflow

有没有办法保存特定的tf.layer而无需指定其所有基础变量?

Tensorflow允许保存单个变量:

saver = tf.train.Saver(var_list={"varName": varName})
saver.save(sess, "path")

但是,这不适用于图层。使用tf.get_collection获取与特定图层关联的变量,然后调用Saver构造函数会导致错误:

Saver(var_list={"varName": variableCollection})

ValueError: Slices must all be slices: <tf.Variable 'vars1/kernel:0' shape=(1, 1) dtype=float32_ref>

我所知道的唯一解决方案是仅传递图层的单个变量,但如果我不必遍历图层中的所有变量来保存它,那么它将更加便利。

Saver(var_list={"varName": variableCollection[0]})

1 个答案:

答案 0 :(得分:0)

如果使用面向对象的图层(如tf.layers.Dense),可以调用layer.variables来获取变量,然后将它们传递给保护程序。