如何使用Tensorflow Estimator api获得一个纪元的总训练集损失

时间:2018-01-20 22:37:32

标签: tensorflow tensorflow-estimator

我正在寻找一种方法来实现学习率的搜索,如下所述:https://arxiv.org/pdf/1506.01186.pdf

为了实现这一点,我需要有一种方法来获得多个学习速率的单个时期的损失。我正在考虑创建一个SessionRunHook并简单地从每一步的损失中取平均值,它不会准确,因为最后一步很可能没有获得batch_size元素,但它应该是好的够了。

您是否已实施此类SessionRunHook,或知道如何在培训期间访问损失或/和批量大小。

2 个答案:

答案 0 :(得分:2)

我想出的是,它没有考虑到最后一个小批量可以更小但是由于我没有运行整个训练集,它应该是okey:

class RecordLossHook(tf.train.SessionRunHook):
  def __init__(self, loss_name):
    self.loss_name = loss_name

  def begin(self):
    self._loss_tensor = tf.get_default_graph().as_graph_element(self.loss_name+":0")
    self.loss_summed = 0
    self.batch_count = 0

  def before_run(self, run_context):
    return tf.train.SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context,  run_values):
    self.loss_summed += run_values.results
    self.batch_count += 1 
    self.loss = self.loss_summed/self.batch_count 

然而,如果有人有更好的钩子来解释最后一个小批量,我很乐意接受这样的答案。

答案 1 :(得分:0)

也许这样也可以工作:

 function Person(args) {
    for(var key in args){
        this[key] = args[key];
     }
 }
相关问题