有没有办法保存特定的tf.layer
而无需指定其所有基础变量?
Tensorflow允许保存单个变量:
saver = tf.train.Saver(var_list={"varName": varName})
saver.save(sess, "path")
但是,这不适用于图层。使用tf.get_collection
获取与特定图层关联的变量,然后调用Saver
构造函数会导致错误:
Saver(var_list={"varName": variableCollection})
ValueError: Slices must all be slices: <tf.Variable 'vars1/kernel:0' shape=(1, 1) dtype=float32_ref>
我所知道的唯一解决方案是仅传递图层的单个变量,但如果我不必遍历图层中的所有变量来保存它,那么它将更加便利。
Saver(var_list={"varName": variableCollection[0]})
答案 0 :(得分:0)
如果使用面向对象的图层(如tf.layers.Dense),可以调用layer.variables来获取变量,然后将它们传递给保护程序。