如何从tf.train.string_input_producer获取epoch num信息

时间:2016-08-23 12:24:32

标签: tensorflow

如果使用string_input_producer读取文件,例如

filename_queue = tf.train.string_input_producer(
  files, 
  num_epochs=num_epochs,
  shuffle=shuffle)

如何在训练期间获取纪元数信息(我希望在培训期间显示此信息) 我试过下面的

run 
tf.get_default_graph().get_tensor_by_name('input_train/input_producer/limit_epochs/epochs:0')

将始终与限制纪元数

相同
run
tf.get_default_graph().get_tensor_by_name('input_train/input_producer/limit_epochs/CountUpTo:0')

每次都会加1 ..

在训练期间,两者都无法获得正确的纪元。

另一件事是,如果从现有模型重新训练,我可以获得已经训练过的纪元数据吗?

1 个答案:

答案 0 :(得分:1)

我认为这里正确的方法是定义一个传递给优化器的global_step变量(或者你可以手动增加它)。

TensorFlow Mechanics 101教程提供了一个示例:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

每次global_step运行时,train_op都会递增。由于您知道数据集的大小和批量大小,因此您将知道您目前所处的时代。

使用tf.train.Saver()保存模型时,global_step变量也会保存。恢复模型后,只需拨打global_step.eval()即可返回上次停止的步长值。

我希望这有帮助!