将保存在一个模型中的单个可变张量恢复为另一模型中的可变张量-Tensorflow

时间:2018-07-14 08:00:30

标签: tensorflow tensor

在tensorflow 1.3.0 GPU上运行。 我已经用TF训练了模型,并使用以下命令仅保存了一个变量张量:

embeddings = tf.Variable(tf.random_uniform([4**kmer_len, embedding_size], -0.04, 0.04), name='Embeddings')
more code, variables...
saver = tf.train.Saver({"Embeddings": embeddings})  # saving only embeddings variable
some more code, training model...
saver.save(ses, './embeddings/embedding_mat')       # saving the variable 

现在,我在不同的文件中使用了不同的模型,我只想将保存的单个 embeddings 变量重新分配给它。问题在于这个新模型还有更多变量。 现在,当我尝试通过以下方式还原变量时:

embeddings = tf.Variable(tf.random_uniform([4**kmer_len_emb, embedding_size], -0.04, 0.04), name='Embeddings')
dense1 = tf.layers.dense(inputs=kmer_flattened, units=200, activation=tf.nn.relu, use_bias=True)  
ses = tf.Session()
init = tf.global_variables_initializer()
ses.run(init)
saver = tf.train.Saver()
saver.restore(ses, './embeddings/embedding_mat')

我在检查点中找不到错误。 关于如何处理这个有什么想法吗? 谢谢

2 个答案:

答案 0 :(得分:1)

这是因为它找不到dense1检查点。试试这个:

all_var = tf.global_variables()

var_to_restore = [v for v in all_var if v.name == 'Embeddings:0']
ses.run(init)

saver = tf.train.Saver(var_to_restore)
saver.restore(ses, './embeddings/embedding_mat')

答案 1 :(得分:1)

您必须仅在该变量上创建Saver的实例:

saver = tf.train.Saver(var_list=[embeddings])

这是对您的Saver实例说的是只恢复/保存该图的特定变量,否则它将尝试恢复/保存该图的所有变量。