Tensorflow:保存先前创建的模型的变量子集

时间:2017-03-29 19:11:38

标签: python machine-learning tensorflow neural-network

我创建了一个带有一堆变量的模型(模型A)。我计划使用模型A中的一些层在新模型(模型B)上使用模型A进行传递学习。但是,模型B具有与模型A相同的体系结构,因此我无法从模型加载所有变量在运行Model B之前,否则会出现命名错误等等。所以,我试图创建一个新的ckpt文件,只存储我想要的模型A的权重。然后我将使用这个新的ckpt文件加载到模型B.我有以下内容:

sess = tf.Session()
saver = tf.train.import_meta_graph('ModelA.ckpt.meta')
saver.restore(sess, 'ModelA.ckpt')

# I did not explicity name my variables in model A so I am just placing them in the list and taking the ones I want

store_list = []
for v in tf.trainable_variables():
    store_list.append(v)

var_list={"W_1": store_list[0], "b_1": store_list[1]}
v2_saver=tf.train.Saver(var_list)
sess.run(tf.global_variables_initializer())
v2_saver.save(sess, 'model_A_subset.ckpt')

但是当我恢复model_A_subset.ckpt时,我仍然拥有ModelA.ckpt中的所有变量。难道我做错了什么?有没有办法可以轻松删除ModelA.ckpt中我不想要的变量并使用它?

1 个答案:

答案 0 :(得分:0)

您确定检查点中有不必要的变量吗?我问,因为在恢复检查点之前,您需要创建图表,如果您正在创建一个包含A的所有变量的图表,那么您将遇到此问题。

要检查检查点并查看实际情况,请尝试inspect checkpoint tool