Is it good practice to make sure hours/days of network training are not wasted if something goes wrong midway through?
Answer 0 (score: 2)
I use a custom callback that stores the last epoch, weights, loss, and so on, so that training can be resumed later:
import json
import socket

from keras.callbacks import ModelCheckpoint


class StatefulCheckpoint(ModelCheckpoint):
    """Save extra checkpoint data to resume training."""

    def __init__(self, weight_file, state_file=None, **kwargs):
        """Save the state (epoch etc.) alongside the weights."""
        super().__init__(weight_file, **kwargs)
        self.state_f = state_file
        self.state = dict()
        # Record which machine wrote the checkpoint (stored in the state below).
        self.hostname = socket.gethostname()
        if self.state_f:
            # Load the last state, if any
            try:
                with open(self.state_f, 'r') as f:
                    self.state = json.load(f)
                self.best = self.state['best']
            except Exception as e:  # pylint: disable=broad-except
                print("Skipping last state:", e)

    def on_epoch_end(self, epoch, logs=None):
        """Save the training state as well as the weights."""
        super().on_epoch_end(epoch, logs)
        if self.state_f:
            state = {'epoch': epoch + 1, 'best': self.best,
                     'hostname': self.hostname}
            state.update(logs or {})
            state.update(self.params)
            with open(self.state_f, 'w') as f:
                json.dump(state, f)

    def get_last_epoch(self, initial_epoch=0):
        """Return the last saved epoch if any, or the default argument."""
        return self.state.get('epoch', initial_epoch)
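To resume an interrupted run, the callback could be wired in roughly like this. This is a minimal sketch: the toy model, the data, and the file names weights.h5 / state.json are placeholders, not part of the answer above.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Toy model and data purely for illustration.
model = Sequential([Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
x_train = np.random.rand(64, 4)
y_train = np.random.rand(64, 1)

ckpt = StatefulCheckpoint('weights.h5', state_file='state.json',
                          save_weights_only=True)
initial_epoch = ckpt.get_last_epoch()
if initial_epoch > 0:
    model.load_weights('weights.h5')  # restore weights from the interrupted run

model.fit(x_train, y_train, epochs=10,
          initial_epoch=initial_epoch,  # Keras resumes counting from here
          callbacks=[ckpt])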
This only works if your epochs take a sensible amount of time, e.g. around an hour, but it is clean and complies with the Keras API.
Answer 1 (score: 1)
A simple solution is to use logging and periodically serialize the model to disk. You can keep at most, say, the 5 most recent versions of the network to avoid running out of disk space.
Python has excellent logging utilities, and you may find pickle useful for serializing models.
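A minimal sketch of that idea, assuming a picklable model object; the save_snapshot helper, the snapshots directory, and the keep-count of 5 are illustrative, not from the answer (for Keras models specifically, model.save is usually the safer serialization route):

import logging
import os
import pickle

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('training')

MAX_KEPT = 5  # keep at most the five most recent snapshots on disk


def save_snapshot(model, epoch, directory='snapshots'):
    """Pickle the model and delete snapshots older than the last MAX_KEPT."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, 'model_%06d.pkl' % epoch)
    with open(path, 'wb') as f:
        pickle.dump(model, f)
    logger.info('Saved snapshot %s', path)
    # Zero-padded names sort chronologically, so pruning is a simple slice.
    snapshots = sorted(os.listdir(directory))
    for stale in snapshots[:-MAX_KEPT]:
        os.remove(os.path.join(directory, stale))
        logger.info('Removed old snapshot %s', stale)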