我有一个使用MonitoredTrainingSession进行分布式计算的RNN。我正在使用global_step来识别每个工作者应该采用哪一批输入数据。
我在创建会话之前定义了张量
global_step_tensor = tf.Variable(0, dtype=tf.int32, trainable=False, name=‘global_step’)
...
minimise = optimiser.minimize(loss, name=‘adam_opt’, global_step=‘global_step’)
with tf.train.MonitoredTrainingSession(...) as sess:
graph=tf.get_default_graph()
curr_step=sess.run(global_step_tensor)
print(curr_step) #gives 366
我认为变量仅在评估优化器时递增?为什么从366开始?
编辑我的群集定义为一个ps和两个worker。目前,在我测试时,所有三个都通过不同的端口在同一主机上运行。
答案 0 :(得分:2)
根据文档,MonitoredTrainingSession
有几个默认参数可以自动生成检查点:
save_checkpoint_secs
:检查点的频率(以秒为单位) 使用默认检查点保护程序保存。如果设置了save_checkpoint_secs 为None,则不使用默认检查点保护程序。
save_summaries_steps
:全局步数的频率 使用默认的摘要保护程序将摘要写入磁盘。如果 save_summaries_steps和save_summaries_secs都设置为None, 然后不使用默认的摘要保护程序。默认为100。
save_summaries_secs
:摘要的频率(以秒为单位) 使用默认的摘要保护程序写入磁盘。如果两者 save_summaries_steps和save_summaries_secs设置为None,然后是 未使用默认摘要保护程序。默认未启用。
也许这就是您当前批次不再是0
的原因。