具有keras的自定义交替更新规则

时间:2017-07-09 11:47:34

标签: optimization tensorflow keras keras-layer keras-2

我想使用keras的交替更新规则。 即每批我想调用基于渐变的常规步骤,然后调用自定义步骤。

我考虑通过继承优化器或回调(并使用批量调用)来实现它。但是,两者都不会,因为它们都缺少批次数据和批次标签(我需要两者)。

如何使用keras实现自定义交替更新?

如果需要,我不介意直接调用tensorflow特定的方法,只要我可以继续使用keras框架包装的项目(使用model.fit,model.predict ..)

1 个答案:

答案 0 :(得分:0)

尝试创建自定义回调

import keras.callbacks as callbacks

class JSONMetrics(callbacks.Callback):

_model      = None
_each_epoch = None
_metrics    = None
_epoch      = None
_file_json  = None 

def __init__(self,model,each_epoch,logger=None):

    self._file_json = "file_log.json"
    self._model     = model
    self._each_epoch= each_epoch
    self._epoch     = 0
    self._metrics   = {'loss':[], 'acc':[]}

def on_epoch_begin(self, epoch, logs):
    # print('Epoch {0} begin'.format(epoch))
    try:
        with open(self._file_json, 'r') as f:   
            self._metrics = json.load(f)

def on_epoch_end(self, epoch, logs):
    self._logger.info('Nemesis: Epoch {0} end'.format(epoch))

    self._metrics['loss'].append(logs.get('loss'))
    self._metrics['acc'].append(logs.get('acc'))
    with open(self._file_json, 'w') as f:
        data = json.dump(self._metrics, f)

    if self._epoch % self._each_epoch == 0:

        file_name = 'weights%08d.h5' % self._epoch
        #print('Saving weights at {0} file'.format(file_name))
        self._model.save_weights(file_name)

    self._epoch += 1

你可以唤起self.model来解决你的问题并保存acc和loss。例如。