在张量流图中保存单个神经网络的权重

时间:2017-10-06 23:28:07

标签: tensorflow

如何在张量流图中保存单个神经网络的权重,以便可以将其加载到具有相同架构的网络中的不同程序中?

我的训练代码仅需要3个其他神经网络用于训练过程。如果我使用saver.save(sess, 'my-model)',它是否会保存张量流图中的所有变量?这对我的用例来说似乎不正确。

也许这来自于我对张量流应该如何工作的误解。我是否正确地解决了这个问题?

1 个答案:

答案 0 :(得分:2)

最好的方法是使用tensorflow变量范围。假设您有model_1,model_2和model_3,并且您只想保存model_1:

首先,在训练代码中定义模型:

model_1_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="model_1")
saver = tf.train.Saver(model_1_variables)

接下来,为model_1的变量定义保护程序:

saver.save(sess, 'my-model')

在训练时,您可以像上面提到的那样保存检查点:

with tf.variable_scope('model_1'):
    model one declaration here
    ...
model_1_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="model_1")
saver = tf.train.Saver(model_1_variables)
sess = tf.Session()
saver.restore(sess, 'my-model')` 

培训结束后,如果要在评估代码中恢复权重,请确保以相同的方式定义model_1和保护程序:

class Node {
    private:
       int key;
       Node* left, right;

    public:
       Node() {
           left = NULL;
           right = NULL;
       }

       Node(int data) : Node() {
           this->key=data;   
       }
};