Is it possible to load weights into a new session? I want to take the weights of one model by restoring them from a checkpoint, get the variables as arrays, and supply them to a new session. Is this possible? This is the code I am currently trying:
def __init__(self, ):
    self.all_values = []

def importer(self):
    new_graph = tf.Graph()
    ckpt = tf.train.get_checkpoint_state('./models/Control')
    with tf.Session(graph=new_graph) as sess1:
        # import control graph and ckpt with trained values
        saver = tf.train.import_meta_graph("./models/control-graph.meta")
        saver.restore(sess1, ckpt.model_checkpoint_path)
        # keep the trained values as numpy arrays
        self.all_values = sess1.run(tf.global_variables())
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        sess1.close()
    # record the variable names so loader() can look them up again
    all_var_names = []
    for i in range(len(all_vars)):
        all_var_names.append(all_vars[i].name)
    np.save("./tmp/all_var_names", all_var_names)

def loader(self):
    all_var_names = np.load("./tmp/all_var_names.npy")
    new_graph = tf.Graph()
    ckpt = tf.train.get_checkpoint_state('./models/TestRun1')
    with tf.Session(graph=new_graph) as sess:
        saver = tf.train.import_meta_graph("./models/control-graph.meta")
        # look up the output tensor of each saved variable by name
        all_ops = []
        all_names = []
        for i in range(len(all_var_names)):
            all_names.append(all_var_names[i].split(":")[0])
            all_ops.append(new_graph.get_operation_by_name(all_names[i]).outputs[0])
        # try to push the stored values back in via feed_dict
        test = zip(all_ops, self.all_values)
        d = dict(test)
        #sess.run(tf.global_variables_initializer())
        sess.run(all_ops, feed_dict=d)
        print(sess.run(tf.trainable_variables()))
        saver.save(sess, './models/Test/output.cptk')
Every alternative I have tried gets me close but never quite meets my requirements. I don't think I can use placeholders, because I'm trying to do this dynamically based on whatever graph and checkpoint values are imported. I just want to load the tensor values into the session. I have also tried restoring an untrained session and updating its variables with the optimized weights, but with no luck. Any suggestions?
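For reference, here is a rough sketch of the kind of thing I am after (untested, and it assumes the variable names saved by importer() exist under the same names in the new graph): instead of passing the arrays through feed_dict, assign them directly to the matching variables with Variable.load (or an equivalent tf.assign op).

import numpy as np
import tensorflow as tf

def push_values(sess, graph, var_names, values):
    # var_names / values are the names and numpy arrays collected by importer()
    with graph.as_default():
        by_name = {v.name: v for v in tf.global_variables()}
        for name, value in zip(var_names, values):
            var = by_name.get(name)
            if var is None:
                print("no variable named", name, "in the new graph")
                continue
            # Variable.load writes the array into the session state directly,
            # so no placeholder or feed_dict is needed
            var.load(value, session=sess)

Called in loader() right after import_meta_graph, this would replace the sess.run(all_ops, feed_dict=d) step, and saver.save(sess, ...) afterwards should then write the copied weights into the new checkpoint. Is something along these lines the right way to do it, or is there a better approach?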