我使用以下代码加载了预训练模型(Model 1
):
def load_seq2seq_model(sess):
with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
saved_args = pickle.load(f)
# Initialize the model with saved args
model = Model1(saved_args)
#Inititalize Tensorflow saver
saver = tf.train.Saver()
# Checkpoint
ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
print('Loading model: ', ckpt.model_checkpoint_path)
# Restore the model at the checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
return model
现在,我想从头开始训练另一个模型(Model 2
),它将获取Model 1
的输出。但为此,我需要定义一个会话并加载预先训练的模型并初始化模型tf.initialize_all_variables()
。因此,预训练的模型也将被初始化。
任何人都可以告诉我如何在从预先训练的模型Model 2
获得正确的输出后训练Model 1
吗?
我正在尝试的内容如下 -
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(tf.initialize_all_variables())
.... Rest of the training code goes here....
答案 0 :(得分:1)
不需要初始化使用保护程序恢复的所有变量。因此,您可以使用tf.initialize_all_variables()
来初始化第二个网络的权重,而不是使用tf.variables_initializer(var_list)
。
要获取第二个网络的所有权重列表,您可以在变量范围内创建Model 2
网络:
with tf.variable_scope("model2"):
model2 = Model2(...)
然后使用
model_2_variables_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES,
scope="model2"
)
获取Model 2
网络的变量列表。最后,您可以为第二个网络创建初始值:
init2 = tf.variables_initializer(model_2_variables_list)
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(init2)
.... Rest of the training code goes here....