我有一个训练有素的检查站。现在,我正在尝试将此预训练模型恢复到当前网络。但是,变量名不同。 Tensorflow document表示使用字典,例如:
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2": v2})
但是,当前网络中的变量的定义如下:
with tf.variable_scope('a'):
b=tf.get_variable(......)
因此,变量名似乎是a/b
。
如何使字典像"v2": a/b
一样?
答案 0 :(得分:1)
您可以使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
获取当前图形中所有变量名称的列表。您还可以指定范围。
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')
您可以使用tf.train.list_variables(ckpt_file)
来获取检查点中所有变量的列表。
假设您的检查点中有变量b,并且您想以tf.variable_scope('a')
的名称加载到a/b
中。为此,您只需定义
with tf.variable_scope('a'):
b=tf.get_variable(......)
并加载
saver = tf.train.Saver({'v2': b})
with tf.Session() as sess:
saver.restore(sess, ckpt_file))
print(b)
这将输出
<tf.Variable 'a/b:0' shape dtype>
编辑:如前所述,您可以使用
vars_dict = {}
for var_current in tf.global_variables():
print(var_current)
print(var_current.op.name) # this gets only name
for var_ckpt in tf.train.list_variables(ckpt):
print(var_ckpt[0]) this gets only name
如果知道所有变量的确切名称,则可以分配所需的任何值,只要变量具有相同的形状和dtype 所以要得到一个命令
vars_dict[var_ckpt[0]) = tf.get_variable(var_current.op.name, shape) # remember to specify shape, you can always get it from var_current
您可以显式地或以您认为合适的任何循环构造此字典。然后将其传递给保护程序
saver = tf.train.Saver(vars_dict)