如何在Tensorflow中重新初始化预训练的加载模型?

时间:2017-10-04 12:04:23

标签: python tensorflow deep-learning

我使用以下代码加载了预训练模型(Model 1):

def load_seq2seq_model(sess):


    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize the model with saved args
    model = Model1(saved_args)

    #Inititalize Tensorflow saver
    saver = tf.train.Saver()

    # Checkpoint 
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
    print('Loading model: ', ckpt.model_checkpoint_path)

    # Restore the model at the checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    return model

现在,我想从头开始训练另一个模型(Model 2),它将获取Model 1的输出。但为此,我需要定义一个会话并加载预先训练的模型并初始化模型tf.initialize_all_variables()。因此,预训练的模型也将被初始化。

任何人都可以告诉我如何在从预先训练的模型Model 2获得正确的输出后训练Model 1吗?

我正在尝试的内容如下 -

with tf.Session() as sess:
    # Initialize all the variables of the graph
    seq2seq_model = load_seq2seq_model(sess)
    sess.run(tf.initialize_all_variables())
    .... Rest of the training code goes here....

1 个答案:

答案 0 :(得分:1)

不需要初始化使用保护程序恢复的所有变量。因此,您可以使用tf.initialize_all_variables()来初始化第二个网络的权重,而不是使用tf.variables_initializer(var_list)

要获取第二个网络的所有权重列表,您可以在变量范围内创建Model 2网络:

with tf.variable_scope("model2"):
    model2 = Model2(...)

然后使用

model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, 
    scope="model2"
)

获取Model 2网络的变量列表。最后,您可以为第二个网络创建初始值:

init2 = tf.variables_initializer(model_2_variables_list)

with tf.Session() as sess:
    # Initialize all the variables of the graph
    seq2seq_model = load_seq2seq_model(sess)
    sess.run(init2)
    .... Rest of the training code goes here....