Tensorflow:如何恢复训练变量,但不是完整模型?

时间:2017-10-09 02:11:07

标签: machine-learning tensorflow conv-neural-network

在Tensorflow中,如何从.index或.data文件恢复变量? 我不想要整个模型,我只想要四个变量。

以下是我一直在尝试的内容:

save_dest = "my_path\\model_name" ## name of saved session

## Set 1:
w1 = tf.Variable([11,11,1,12])
w1_bias = tf.Variable(tf.zeros(12))  
w2 = tf.Variable([7, 7, 12, 32])
w2_bias = tf.Variable(tf.zeros(32))

## Set 2:
#   w1 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 
v.name=="w1"][0] 
#w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 
v.name=="w1_bias"][0] 
#w2 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 
v.name=="w2"][0] 
#w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 
v.name=="w2_bias"][0] 

## Set 3:
#w1 = tf.get_variable("w1",shape=[4])
#w1_bias = tf.get_variable("w1_bias",shape=[1])  
#w2 = tf.get_variable("w2",shape=[4])
#w2_bias = tf.get_variable("w2_bias",shape=[1])
print("________________")
print_tensors_in_checkpoint_file(save_dest,tensor_name='',all_tensors=False)
print("________________")

# saver1 = tf.train.import_meta_graph(save_dest+".meta")   
#saver1 = tf.train.Saver(["w1:0","w1_bias:0","w2:0","w2_bias:0"])
#tf.reset_default_graph()
#saver1 = 
tf.train.Saver({"w1":w1,"w1_bias":w1_bias,"w2":w2,"w2_bias":w2_bias})
#saver1 = tf.train.Saver(var_list=["w1","w1_bias","w2","w2_bias"])
#saver1 = tf.train.Saver()
with tf.Session() as session:
    saver1 = tf.train.Saver()    
    tf_valid_dataset = tf.constant(valid_dataset)
    tf.global_variables_initializer().run()
    #saver1 = tf.train.import_meta_graph(save_dest)
    #saver1 = 
tf.train.Saver({"w1":w1,"w1_bias":w1_bias,"w2":w2,"w2_bias":w2_bias})

    #saver1.restore(session,save_dest)
    saver1.restore(session, save_dest)
    #w1 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 
                 v.name=="w1:0"][0] 
    #w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 
                 if v.name=="w1_bias:0"][0] 
    #w2 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 
                v.name=="w2:0"][0] 
    #w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 
                if v.name=="w2_bias:0"][0] 

    g = tf.get_default_graph()

    w1 = g.get_tensor_by_name("w1:0")
    w1_bias = g.get_tensor_by_name("w1_bias:0")
    w2 = g.get_tensor_by_name("w2:0")
    w2_bias = g.get_tensor_by_name("w2_bias:0")


    #preds = session.run(model(tf_valid_dataset))
    #test_prediction = session.run(model(valid_dataset),
    #                        feed_dict={keep_prob:1.0})

    print(np.shape(preds), type(preds))

基本上我一直在阅读相关问题的答案(主要是如何恢复模型)并尝试实现它。但是,它总是会引起某种错误。当前save_dest中的唯一项是训练的参数矩阵,名为“w1”,“w1_bias”,“w2”,“w2_bias”。即

 w1 (DT_FLOAT) [11,11,1,12]  
 w1_bias (DT_FLOAT) [12]  
 w2 (DT_FLOAT) [7,7,12,32]  
 w2_bias (DT_FLOAT) [32]  

引发的错误包括variable_01不存在,或者维度4对象不能适合形状[11,11,1,12]等。我对如何解决这个问题感到茫然。任何帮助将不胜感激!

0 个答案:

没有答案