在Tensorflow中,如何从.index或.data文件恢复变量? 我不想要整个模型,我只想要四个变量。
以下是我一直在尝试的内容:
save_dest = "my_path\\model_name" ## name of saved session
## Set 1:
w1 = tf.Variable([11,11,1,12])
w1_bias = tf.Variable(tf.zeros(12))
w2 = tf.Variable([7, 7, 12, 32])
w2_bias = tf.Variable(tf.zeros(32))
## Set 2:
# w1 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if
v.name=="w1"][0]
#w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if
v.name=="w1_bias"][0]
#w2 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if
v.name=="w2"][0]
#w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if
v.name=="w2_bias"][0]
## Set 3:
#w1 = tf.get_variable("w1",shape=[4])
#w1_bias = tf.get_variable("w1_bias",shape=[1])
#w2 = tf.get_variable("w2",shape=[4])
#w2_bias = tf.get_variable("w2_bias",shape=[1])
print("________________")
print_tensors_in_checkpoint_file(save_dest,tensor_name='',all_tensors=False)
print("________________")
# saver1 = tf.train.import_meta_graph(save_dest+".meta")
#saver1 = tf.train.Saver(["w1:0","w1_bias:0","w2:0","w2_bias:0"])
#tf.reset_default_graph()
#saver1 =
tf.train.Saver({"w1":w1,"w1_bias":w1_bias,"w2":w2,"w2_bias":w2_bias})
#saver1 = tf.train.Saver(var_list=["w1","w1_bias","w2","w2_bias"])
#saver1 = tf.train.Saver()
with tf.Session() as session:
saver1 = tf.train.Saver()
tf_valid_dataset = tf.constant(valid_dataset)
tf.global_variables_initializer().run()
#saver1 = tf.train.import_meta_graph(save_dest)
#saver1 =
tf.train.Saver({"w1":w1,"w1_bias":w1_bias,"w2":w2,"w2_bias":w2_bias})
#saver1.restore(session,save_dest)
saver1.restore(session, save_dest)
#w1 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if
v.name=="w1:0"][0]
#w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if v.name=="w1_bias:0"][0]
#w2 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if
v.name=="w2:0"][0]
#w1_bias = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if v.name=="w2_bias:0"][0]
g = tf.get_default_graph()
w1 = g.get_tensor_by_name("w1:0")
w1_bias = g.get_tensor_by_name("w1_bias:0")
w2 = g.get_tensor_by_name("w2:0")
w2_bias = g.get_tensor_by_name("w2_bias:0")
#preds = session.run(model(tf_valid_dataset))
#test_prediction = session.run(model(valid_dataset),
# feed_dict={keep_prob:1.0})
print(np.shape(preds), type(preds))
基本上我一直在阅读相关问题的答案(主要是如何恢复模型)并尝试实现它。但是,它总是会引起某种错误。当前save_dest中的唯一项是训练的参数矩阵,名为“w1”,“w1_bias”,“w2”,“w2_bias”。即
w1 (DT_FLOAT) [11,11,1,12]
w1_bias (DT_FLOAT) [12]
w2 (DT_FLOAT) [7,7,12,32]
w2_bias (DT_FLOAT) [32]
引发的错误包括variable_01不存在,或者维度4对象不能适合形状[11,11,1,12]等。我对如何解决这个问题感到茫然。任何帮助将不胜感激!