I've trained a model in TensorFlow, and now I want to find which inputs maximally activate one of its outputs. I'm wondering what the cleanest way to do this is.

I thought of doing it by creating a trainable input variable which I can assign once per run. Then, by using an appropriate loss function and an optimizer whose var_list contains only this input variable, I would update the input variable until convergence. I.e.:
```python
trainable_input = tf.get_variable(
    'trainable_input',
    shape=data_op.get_shape(),
    dtype=data_op.dtype,
    initializer=tf.zeros_initializer(),
    trainable=True,
    collections=[tf.GraphKeys.LOCAL_VARIABLES])
trainable_input_assign_op = tf.assign(trainable_input, data_op)
data_op = trainable_input

# ... run the rest of the graph building code here, now with a trainable input

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
# loss_op is defined on one of the outputs
train_op = optimizer.minimize(loss_op, var_list=[trainable_input])
```
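For concreteness, loss_op could be something like the following, where output_op and target_class are hypothetical placeholders for my graph's output tensor and the unit whose activation I want to maximize:

```python
# Hypothetical sketch: maximize the mean activation of output unit
# `target_class` by minimizing its negation. `output_op` and `target_class`
# are placeholders for the actual names in your graph.
loss_op = -tf.reduce_mean(output_op[:, target_class])
```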
However, when I do this I run into problems. If I try to restore the pre-trained graph using a Supervisor, it naturally complains that the new variables created by the AdamOptimizer don't exist in the graph I'm trying to restore. I can work around that by using get_slots to fetch the variables the AdamOptimizer creates and manually adding them to the tf.GraphKeys.LOCAL_VARIABLES collection, but it feels very hacky and I'm not sure what the consequences would be. I can also explicitly exclude those variables from the Saver that gets passed to the Supervisor, without adding them to the tf.GraphKeys.LOCAL_VARIABLES collection, but then I get an exception saying they were never properly initialized by the Supervisor:
File "/usr/local/lib/python3.5/site-packages/tensorflow/python/training/supervisor.py", line 973, in managed_session
self.stop(close_summary_writer=close_summary_writer)
File "/usr/local/lib/python3.5/site-packages/tensorflow/python/training/supervisor.py", line 801, in stop
stop_grace_period_secs=self._stop_grace_secs)
File "/usr/local/lib/python3.5/site-packages/tensorflow/python/training/coordinator.py", line 386, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python3.5/site-packages/six.py", line 686, in reraise
raise value
File "/usr/local/lib/python3.5/site-packages/tensorflow/python/training/supervisor.py", line 962, in managed_session
start_standard_services=start_standard_services)
File "/usr/local/lib/python3.5/site-packages/tensorflow/python/training/supervisor.py", line 719, in prepare_or_wait_for_session
init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
File "/usr/local/lib/python3.5/site-packages/tensorflow/python/training/session_manager.py", line 280, in prepare_session
self._local_init_op, msg))
RuntimeError: Init operations did not make model ready. Init op: init, init fn: None, local_init_op: name: "group_deps_5"
op: "NoOp"
input: "^init_1"
input: "^init_all_tables"
, error: Variables not initialized: trainable_input/trainable_input/Adam, trainable_input/trainable_input/Adam_1
I'm not sure why these variables aren't being initialized, since I have used this technique before to exclude variables (both GLOBAL and LOCAL) from the restore process, and they seemed to get initialized as expected.
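For reference, the get_slots workaround I mentioned above would look roughly like this (a sketch; note that Adam may also keep beta-power accumulators that are not exposed as slots, which is part of why it feels fragile):

```python
# Move the Adam slot variables ('m' and 'v') for the input variable into
# LOCAL_VARIABLES so the Saver skips them. This must run after
# optimizer.minimize(), which is what creates the slots.
for slot_name in optimizer.get_slot_names():
    slot_var = optimizer.get_slot(trainable_input, slot_name)
    if slot_var is not None:
        tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, slot_var)
```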
In short, my question is: is there a simple way to add an optimizer to the graph and do a checkpoint restore (where the checkpoint doesn't contain the optimizer's variables), without having to poke at the optimizer's internals? If that's not possible, is there any downside to adding the optimizer's variables to the LOCAL_VARIABLES collection?
Answer 0 (score: 0)
The same error occurs when using the slim library; in fact, slim.learning.train() uses tf.train.Supervisor internally. I hope my answer on the GitHub issue can help with your Supervisor problem.

I had the same problem as you and solved it with the following two steps.

1. Pass a saver to slim.learning.train():
```python
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
saver = tf.train.Saver(var_list=optimistic_restore_vars(ckpt.model_checkpoint_path) if ckpt else None)
```
where the function optimistic_restore_vars is defined as:
```python
def optimistic_restore_vars(model_checkpoint_path):
    """Return only the global variables whose name and shape match an entry in the checkpoint."""
    reader = tf.train.NewCheckpointReader(model_checkpoint_path)
    saved_shapes = reader.get_variable_to_shape_map()
    var_names = sorted([(var.name, var.name.split(':')[0])
                        for var in tf.global_variables()
                        if var.name.split(':')[0] in saved_shapes])
    restore_vars = []
    name2var = dict(zip(map(lambda x: x.name.split(':')[0], tf.global_variables()),
                        tf.global_variables()))
    with tf.variable_scope('', reuse=True):
        for var_name, saved_var_name in var_names:
            curr_var = name2var[saved_var_name]
            var_shape = curr_var.get_shape().as_list()
            if var_shape == saved_shapes[saved_var_name]:
                restore_vars.append(curr_var)
    return restore_vars
```
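In other words, the Saver's var_list is built only from variables whose name and shape both match an entry in the checkpoint, so the new optimizer variables never enter it and the restore no longer complains about them.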
2. Pass local_init_op = tf.global_variables_initializer() to slim.learning.train() to initialize the newly added variables:
```python
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
saver = tf.train.Saver(var_list=optimistic_restore_vars(ckpt.model_checkpoint_path) if ckpt else None)
local_init_op = tf.global_variables_initializer()

###########################
# Kicks off the training. #
###########################
slim.learning.train(
    train_tensor,
    saver=saver,
    local_init_op=local_init_op,
    logdir=FLAGS.train_dir,
    master=FLAGS.master,
    is_chief=(FLAGS.task == 0),
    init_fn=_get_init_fn(),
    summary_op=summary_op,
    number_of_steps=FLAGS.max_number_of_steps,
    log_every_n_steps=FLAGS.log_every_n_steps,
    save_summaries_secs=FLAGS.save_summaries_secs,
    save_interval_secs=FLAGS.save_interval_secs,
    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
```
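If you are using tf.train.Supervisor directly rather than going through slim, the same two fixes should carry over, since Supervisor accepts the same saver and local_init_op arguments. A minimal sketch, assuming checkpoint_dir is a hypothetical placeholder for wherever your pre-trained checkpoint lives:

```python
# Sketch: the two fixes applied to a plain tf.train.Supervisor setup.
# `checkpoint_dir` is a hypothetical placeholder.
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
saver = tf.train.Saver(var_list=optimistic_restore_vars(ckpt.model_checkpoint_path) if ckpt else None)
sv = tf.train.Supervisor(logdir=checkpoint_dir,
                         saver=saver,
                         # Give every variable (including the new Adam slots
                         # the checkpoint knows nothing about) an initial
                         # value; the saver then restores only the variables
                         # that actually exist in the checkpoint.
                         local_init_op=tf.global_variables_initializer())
with sv.managed_session() as sess:
    sess.run(trainable_input_assign_op)
    # ... run train_op until convergence ...
```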