在Tensorflow中,我们可以使用Between-graph Replication
构建和创建多个Tensorflow会话,以进行分布式培训。 MonitoredTrainingSession()
协调多个Tensorflow会话,checkpoint_dir
有一个参数MonitoredTrainingSession()
来恢复Tensorflow会话/图表。现在我有以下问题:
tf.train.Saver()
的对象按saver.restore(...)
恢复Tensorflow图。但是,我们如何使用MonitoredTrainingSession()
?MonitoredTrainingSession()
如何使用测试(或预测)模式?我阅读了Tensorflow Doc,但未找到这两个问题的答案。如果有人有解决方案我真的很感激。谢谢!
答案 0 :(得分:3)
简短回答:
答案很长:
我将更新我的答案,因为我自己可以更好地了解tf.train.MonitoredSession可以做些什么(tf.train.MonitoredTrainingSession只是创建一个专门版本的tf.train.MonitoredSession,可以看到在source code)。
以下示例代码显示了如何每5秒将检查点保存到' ./ ckpt_dir'。中断后,它将在上次保存的检查点重新启动:
def train(inputs, labels_onehot, global_step):
out = tf.contrib.layers.fully_connected(
inputs,
num_outputs=10,
activation_fn=tf.nn.sigmoid)
loss = tf.reduce_mean(
tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=out,
labels=labels_onehot), axis=1))
train_op = opt.minimize(loss, global_step=global_step)
return train_op
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
inputs = ...
labels_onehot = ...
train_op = train(inputs, labels_onehot, global_step)
with tf.train.MonitoredTrainingSession(
checkpoint_dir='./ckpt_dir',
save_checkpoint_secs=5,
hooks=[ ... ] # Choose your hooks
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
为了实现这一点,MonitoredTrainingSession中发生的事情实际上是三件事:
为了使其工作,必须将tf.train.CheckpointSaverHook和tf.train.ChiefSessionCreator传递给checkpoint目录和scaffold的相同引用。如果上面的例子中的tf.train.MonitoredTrainingSession及其参数是用上面的3个组件实现的,那么它看起来像这样:
checkpoint_dir = './ckpt_dir'
scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
checkpoint_dir=checkpoint_dir,
save_secs=5
scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir
)
with tf.train.MonitoredSession(
session_creator=session_creator,
hooks=[saverhook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
为了进行训练+交叉验证会话,您可以将tf.train.MonitoredSession.run_step_fn()与partial一起使用,该部分运行会话调用而不调用任何挂钩。这看起来是你为 n 迭代训练你的模型,然后运行你的测试集,重新初始化你的迭代器并返回训练你的模型等等。当然,你必须设置你的要执行此操作时要重用的变量= tf.AUTO_REUSE。在代码中执行此操作的方法如下所示:
from functools import partial
# Build model
...
with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
...
...
def step_fn(fetches, feed_dict, step_context):
return step_context.session.run(fetches=fetches, feed_dict=feed_dict)
with tf.train.MonitoredTrainingSession(
checkpoint_dir=...,
save_checkpoint_steps=...,
hooks=[...],
...
) as mon_sess:
# Initialize iterators (assuming tf.Databases are used)
mon_sess.run_step_fn(
partial(
step_fn,
[train_it.initializer,
test_it.initializer,
...
],
{}
)
)
while not mon_sess.should_stop():
# Train session
for i in range(n):
try:
train_results = mon_sess.run(<train_fetches>)
except Exception as e:
break
# Test session
while True:
try:
test_results = mon_sess.run(<test_fetches>)
except Exception as e:
break
# Reinitialize parameters
mon_sess.run_step_fn(
partial(
step_fn,
[train_it.initializer,
test_it.initializer,
...
],
{}
)
)
partial函数只是在step_fn上执行currying(函数式编程中的经典函数),在mon_sess.run_step_fn()中使用。上面的整个代码尚未经过测试,您可能必须在开始测试会话之前重新初始化train_it,但希望现在很清楚如何在同一个运行中运行训练集和验证集。如果你想在同一个图中同时绘制训练曲线和测试曲线,这可以与张量板的custom_scalar tool一起使用。
最后,这是我能够做到的这个功能的最佳实现,我个人希望tensorflow将来更容易实现这个功能,因为它非常繁琐,可能效率不高。我知道有Estimator这样的工具可以运行train_and_evaluate函数,但是由于这会在每次列车验证和交叉验证运行之间重建图形,如果你只运行一个,它的效率非常低单台电脑。我在某处看到Keras + tf具有此功能,但由于我不使用Keras + tf,因此这不是一个选项。无论如何,我希望这可以帮助其他人在那里挣扎同样的事情!
答案 1 :(得分:0)
您应该导入元图,然后恢复模型。 从下面的片段中获取灵感,应该适合你。
self.sess = tf.Session()
ckpt = tf.train.latest_checkpoint("location-of/model")
saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
saver.restore(self.sess, ckpt)
答案 2 :(得分:-1)