Is it possible to rename the variable scope of a given model in TensorFlow?

For instance, I created a logistic regression model for MNIST digits based on the tutorial:

    import tensorflow as tf

    NUM_IMAGE_PIXELS = 784
    NUM_CLASS_BINS = 10
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])
    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS, NUM_CLASS_BINS]))
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    saver = tf.train.Saver([W, b])
    ...  # some training happens (sess is the tf.Session used for training)
    saver.save(sess, 'my-model')

Now I want to reload the saved model inside a with tf.variable_scope('my-first-scope'): block, and then save everything again to a new file under the new variable scope 'my-first-scope'.
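Answer 0 (score: 24)
Based on keveman's answer, I created a Python script that you can run to rename the variables of any TensorFlow checkpoint:
https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96
You can replace substrings in the variable names and add a prefix to all names. Call the script with
    python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir
and the optional arguments
    --replace_from=substr --replace_to=substr --add_prefix=abc --dry_run
Here is the core function of the script: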

    import tensorflow as tf  # TensorFlow 1.x API (tf.contrib was removed in 2.x)

    def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run=False):
        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
        with tf.Session() as sess:
            for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
                # Load the variable
                var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                # Set the new name
                new_name = var_name
                if None not in [replace_from, replace_to]:
                    new_name = new_name.replace(replace_from, replace_to)
                if add_prefix:
                    new_name = add_prefix + new_name
                if dry_run:
                    print('%s would be renamed to %s.' % (var_name, new_name))
                else:
                    print('Renaming %s to %s.' % (var_name, new_name))
                    # Rename the variable
                    var = tf.Variable(var, name=new_name)
            if not dry_run:
                # Save the variables (this overwrites the latest checkpoint in place)
                saver = tf.train.Saver()
                sess.run(tf.global_variables_initializer())
                saver.save(sess, checkpoint.model_checkpoint_path)

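For anyone wondering how the flags listed above could reach a function like rename(), here is a minimal sketch using the standard argparse module. The helper name build_parser and the help texts are my own assumptions for illustration; the linked gist may wire its arguments differently.

```python
import argparse


def build_parser():
    """Hypothetical parser for the flags described in the answer."""
    parser = argparse.ArgumentParser(
        description="Rename variables in a TensorFlow checkpoint.")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Directory containing the checkpoint.")
    parser.add_argument("--replace_from", default=None,
                        help="Substring to replace in variable names.")
    parser.add_argument("--replace_to", default=None,
                        help="Replacement for --replace_from.")
    parser.add_argument("--add_prefix", default=None,
                        help="Prefix to prepend to every variable name.")
    parser.add_argument("--dry_run", action="store_true",
                        help="Only print what would be renamed.")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch
# can be tried without a command line.
args = build_parser().parse_args([
    "--checkpoint_dir=path/to/dir",
    "--replace_from=scope1",
    "--replace_to=scope1/model",
    "--add_prefix=abc/",
])
print(args.checkpoint_dir, args.replace_from, args.replace_to,
      args.add_prefix, args.dry_run)
```

The parsed values could then be passed on as rename(args.checkpoint_dir, args.replace_from, args.replace_to, args.add_prefix, args.dry_run).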
Example:
    python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/
will rename the variable scope1/Variable1 to abc/scope1/model/Variable1.
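The naming rule behind that example can be tried without TensorFlow. The sketch below isolates the string logic from the core function (substring replacement first, then the prefix). The function name new_variable_name is a hypothetical helper introduced here, not part of the script.

```python
def new_variable_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    """Apply the substring replacement first, then prepend the prefix."""
    new_name = var_name
    # The replacement only happens when both arguments are given.
    if replace_from is not None and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


print(new_variable_name("scope1/Variable1", "scope1", "scope1/model", "abc/"))
# prints: abc/scope1/model/Variable1
```

Because the prefix is added after the replacement, the prefix itself is never affected by --replace_from.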
Answer 1 (score: 9)
You can use tf.contrib.framework.list_variables and tf.contrib.framework.load_variable as follows to achieve your goal:

    import tensorflow as tf  # TensorFlow 1.x API

    # Build the original model under 'my-first-scope' and save it.
    with tf.Graph().as_default(), tf.Session().as_default() as sess:
        with tf.variable_scope('my-first-scope'):
            NUM_IMAGE_PIXELS = 784
            NUM_CLASS_BINS = 10
            x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
            y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])
            W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS, NUM_CLASS_BINS]))
            b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))
            y = tf.nn.softmax(tf.matmul(x, W) + b)
            cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
            saver = tf.train.Saver([W, b])
        sess.run(tf.global_variables_initializer())
        saver.save(sess, 'my-model')

    # List the (name, shape) pairs stored in the checkpoint in the current directory.
    saved_vars = tf.contrib.framework.list_variables('.')

    # In a fresh graph, recreate each variable under the new scope name and save again.
    with tf.Graph().as_default(), tf.Session().as_default() as sess:
        new_vars = []
        for name, shape in saved_vars:
            v = tf.contrib.framework.load_variable('.', name)
            new_vars.append(tf.Variable(v, name=name.replace('my-first-scope', 'my-second-scope')))
        saver = tf.train.Saver(new_vars)
        sess.run(tf.global_variables_initializer())
        saver.save(sess, 'my-new-model')

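One thing to keep in mind with name.replace(...) is that it rewrites every occurrence of the substring, not only a leading scope. If only the top-level scope should change, a stricter rule might look like the sketch below. rename_scope is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def rename_scope(name, old_scope, new_scope):
    """Swap a leading scope component, leaving all other names untouched."""
    if name == old_scope:
        return new_scope
    prefix = old_scope + "/"
    if name.startswith(prefix):
        return new_scope + "/" + name[len(prefix):]
    return name


# A variable directly under the old scope is moved.
print(rename_scope("my-first-scope/Variable", "my-first-scope", "my-second-scope"))
# A scope that merely starts with the same text is left alone,
# whereas str.replace would have rewritten it.
print(rename_scope("my-first-scope-2/Variable", "my-first-scope", "my-second-scope"))
```

In the loop above, this would take the place of name.replace('my-first-scope', 'my-second-scope') when the stricter behaviour is wanted.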
Answer 2 (score: 0)
Another simple script for renaming variables, and thereby changing their scope names:

    import tensorflow as tf

    OLD_CHECKPOINT_FILE = "model.ckpt"
    NEW_CHECKPOINT_FILE = "model_renamed.ckpt"

    vars_to_rename = {
        "scope_1/var1": "scope_2/var1",
        "scope_1/var2": "scope_2/var2",
    }
    new_checkpoint_vars = {}
    reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT_FILE)
    for old_name in reader.get_variable_to_shape_map():
        if old_name in vars_to_rename:
            new_name = vars_to_rename[old_name]
        else:
            new_name = old_name
        new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(new_checkpoint_vars)

    with tf.Session() as sess:
        sess.run(init)
        saver.save(sess, NEW_CHECKPOINT_FILE)
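
After a rename like this, it can be worth checking that every variable arrived under its new name with an unchanged shape. The sketch below works on plain dicts of name to shape, such as those returned by reader.get_variable_to_shape_map() for the old and the new checkpoint. check_rename and the sample shapes are assumptions made up for illustration.

```python
def check_rename(old_shapes, new_shapes, vars_to_rename):
    """Return a list of problems; an empty list means the rename looks consistent."""
    problems = []
    for old_name, shape in old_shapes.items():
        # Variables not listed in the mapping are expected to keep their name.
        expected = vars_to_rename.get(old_name, old_name)
        if expected not in new_shapes:
            problems.append("missing: %s (from %s)" % (expected, old_name))
        elif list(new_shapes[expected]) != list(shape):
            problems.append("shape changed: %s" % expected)
    return problems


# Hypothetical shape maps standing in for the two checkpoints.
old_shapes = {"scope_1/var1": [784, 10], "scope_1/var2": [10]}
new_shapes = {"scope_2/var1": [784, 10], "scope_2/var2": [10]}
mapping = {"scope_1/var1": "scope_2/var1", "scope_1/var2": "scope_2/var2"}
print(check_rename(old_shapes, new_shapes, mapping))
# prints: []
```

This only compares names and shapes. Comparing the tensor values themselves would additionally need reader.get_tensor on both checkpoints.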