Python / Keras - 每N批次后保存模型权重

时间:2017-05-05 01:14:04

标签: keras

我是Python和Keras的新手,我已经成功构建了一个神经网络,可以在每个Epoch之后保存重量文件。但是,我想要更多粒度(我在时间序列中可视化层权重分布)并且希望在每N个批次之后保存权重,而不是每个时期。

有人有任何建议吗?

2 个答案:

答案 0 :(得分:12)

您可以创建自己的回调(https://keras.io/callbacks/)。类似的东西:

function Upload(const aFilePath: string): boolean;
var
  vData: TMultipartFormData; // uses System.Net.Mime
  vHTTP: THTTPClient; // uses System.Net.HttpClient
  vCRC: cardinal;
  vURL: string;
  vResp: TStringStream;
begin
  vURL := 'PHP url';
  vResp := TStringStream.Create('');
  vData := TMultipartFormData.Create();
  vHTTP := THTTPClient.Create;
  try
    try
      vData.AddField('version', MyVerField.ToString);
      vData.AddField('crc', MyCRC.ToString);

      vData.AddFile('db_file', aFilePath);
      Result := vHTTP.Post(vURL, vData, vResp).StatusCode = 200;

      if Result then 
        Result := vResp.DataString.ContentAsString().Contains('"result":true');          

    except
      Result := false;
    end;
  finally
    vHTTP.Free;
    vData.Free;
    vResp.Free; 
  end;
end;

我使用from keras.callbacks import Callback class WeightsSaver(Callback): def __init__(self, N): self.N = N self.batch = 0 def on_batch_end(self, batch, logs={}): if self.batch % self.N == 0: name = 'weights%08d.h5' % self.batch self.model.save_weights(name) self.batch += 1 代替提供的self.batch参数,因为后者在每个纪元重新开始为0.

然后将其添加到您的健康呼叫中。例如,每5批保存一次重量:

batch

答案 1 :(得分:0)

如grovina所述,您可以创建自己的回调。 https://keras.io/callbacks/ 在此处查看“回调”的来源:https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L148

在Callback类下,有许多功能可以满足您的需求。在您的情况下,如果要每N个时间段保存一次模型,则定义函数“ on_epoch_end”。

示例代码

class WeightsSaver(Callback):
  def __init__(self, N):
    self.N = N
    self.epoch = 0

  def on_epoch_end(self, epoch, logs={}):
    if self.epoch % self.N == 0:
        name = ('weights%04d.hdf5') % self.epoch
        self.model.save_weights(name)
    self.epoch += 1

callbacks_list = [WeightsSaver(10)] #save every 10 models
model.fit(train_X,train_Y,epochs=n_epochs,batch_size=batch_size, callbacks=callbacks_list )