如果发生故障,如何在Keras中不丢失以往的跑步记录?

时间:2018-09-20 20:51:53

标签: python tensorflow neural-network keras

如果中间出现问题,不浪费数小时/数天的网络训练是一种好习惯吗?

2 个答案:

答案 0 :(得分:2)

我使用了一个自定义回调,该回调存储了最后一个纪元,权重,损失等,以便随后恢复:

class StatefulCheckpoint(ModelCheckpoint):
  """Save extra checkpoint data to resume training."""
  def __init__(self, weight_file, state_file=None, **kwargs):
    """Save the state (epoch etc.) along side weights."""
    super().__init__(weight_file, **kwargs)
    self.state_f = state_file
    self.state = dict()
    if self.state_f:
      # Load the last state if any
      try:
        with open(self.state_f, 'r') as f:
          self.state = json.load(f)
        self.best = self.state['best']
      except Exception as e: # pylint: disable=broad-except
        print("Skipping last state:", e)

  def on_epoch_end(self, epoch, logs=None):
    """Saves training state as well as weights."""
    super().on_epoch_end(epoch, logs)
    if self.state_f:
      state = {'epoch': epoch+1, 'best': self.best,
               'hostname': self.hostname}
      state.update(logs)
      state.update(self.params)
      with open(self.state_f, 'w') as f:
        json.dump(state, f)

  def get_last_epoch(self, initial_epoch=0):
    """Return last saved epoch if any, or return default argument."""
    return self.state.get('epoch', initial_epoch)

仅当您的时期是合理的时间时,此方法才有效。 1小时,但干净且符合Keras API。

答案 1 :(得分:1)

一个简单的解决方案是使用日志记录并定期将模型序列化到磁盘。您最多可以保留5个版本的网络,以避免用尽磁盘内存。

Python具有出色的logging utilities,您可能会发现pickle对序列化模型很有用。