load_model是否加载优化器状态?

时间:2019-06-28 10:49:02

标签: python tensorflow keras

我通过keras.models.load_model加载了通过model.save保存的模型

然后我正在尝试打印优化器状态:

from keras import backend as K
tf_session = K.get_session()
print(model.optimizer.iterations.eval(session=tf_session))
print(model.optimizer.lr.eval(session=tf_session))

哪些印刷品:

<tf.Variable 'Adadelta/iterations:0' shape=() dtype=int64_ref>
<tf.Variable 'Adadelta/lr:0' shape=() dtype=float32_ref>
0
1.0

或以其他方式获取优化器参数:

print(model.optimizer.get_config())
{'lr': 1.0, 'rho': 0.95, 'decay': 0.0, 'epsilon': 1e-07}

所以我的问题是keras是否在模型加载时重置优化器状态?

据此 https://github.com/keras-team/keras/blob/master/keras/engine/saving.py#L473 它应该保存模型的优化器状态。

这是保存优化器状态的实际代码: https://github.com/keras-team/keras/blob/613aeff37a721450d94906df1a3f3cc51e2299d4/keras/engine/saving.py#L132

优化器配置: https://github.com/keras-team/keras/blob/613aeff37a721450d94906df1a3f3cc51e2299d4/keras/engine/saving.py#L146

优化器权重: https://github.com/keras-team/keras/blob/613aeff37a721450d94906df1a3f3cc51e2299d4/keras/engine/saving.py#L157

更新:

model.optimizer.weights包含什么?

keras.__version__ 2.1.6

print('len(model.get_weights())', len(model.get_weights()))
w1 = model.get_weights()[0]
print('type(w1)', type(w1))
print('w1.shape', w1.shape)

len(model.get_weights()) 86
type(w1) <class 'numpy.ndarray'>
w1.shape (3, 3, 3, 16)

print('len(model.optimizer.get_weights())', len(model.optimizer.get_weights()))
w2 = model.optimizer.get_weights()[0]
print('type(w2)', type(w2))
print('w2.shape', w2.shape)

len(model.optimizer.get_weights()) 116
type(w2) <class 'numpy.ndarray'>
w2.shape (3, 3, 3, 16)

print('max abs diff w1-w2', np.max(np.abs(w1-w2)))
max abs diff w1-w2 0.8932746

1 个答案:

答案 0 :(得分:1)

它应该保存状态。加载时不会重置状态。

检查此问题的正确方法是使用model.optimizer.weights列表:

model = load_model(....)
loaded_optimizer_states = [K.eval(w) for w in model.optimizer.weights]

#resetting the optimizer
model.compile(optimizer='adadelta', ...)
reset_optimizer_states = [K.eval(w) for w in model.optimizer.weights]

for w1,w2 in zip(loaded_optimizer_states,reset_optimizer_states):
    print('equal?', (w1==w2).all())

现在,它不一定能保存我们想要的一切。例如,lr通常不是权重,而只是配置。将使用lr值对内部iterations进行内部计算。

但是您还可以在source code的优化程序的get_updates方法中看到:

  • SGD将迭代保存为权重:self.weights = [self.iterations] + moments
  • 但是Adadelta不:self.weights = accumulators + delta_accumulators

因此,尽管应该保存权重,但是您正在查看错误的变量,并且Adadelta似乎有一个错误代码。如果您将decayAdadelta一起使用,则可能应该手动保存并加载iterations或创建优化程序代码的自定义副本,然后在其中将iterations添加到{{1} }通过以下方式更改行:

weights

查看代码,看来self.weights = [self.iterations] + accumulators + delta_accumulators 是唯一实际保存SGD的代码,这似乎是保存/加载优化器状态的一般错误。

打开了这个问题:https://github.com/keras-team/keras/issues/13027

什么是iterations

它们是两件事:

  • model.optimizer.weights:模型的权重,即使您没有优化器也可以使模型正常工作(如果您只是想进行预测,可以不经编译就使用模型)
  • model.weightsmodel.optimizer.weights的状态。它们不一定与模型的权重有关,它们只是定义优化器在训练时应如何“更新”模型的权重。

现在,列表中的每个权重是什么?

这很大程度上取决于您所使用的优化程序。您可以查看源代码以了解每个优化器保存为状态的内容。

optimizer优化器具有SGD。这意味着SGD的状态将保存当前迭代(用于在出现衰减时定义当前self.weights = [self.iterations] + moments)和优化器的lr

moments是一个包含张量的列表,张量的形状与moments的列表相同。因为每个模型权重都有动量。

其他优化器使用更复杂的数学计算,并且可以具有更多的东西作为优化器权重。例如,model.get_weights()具有Adadeltaaccumulators。我不知道它们是什么,我们应该研究此优化器的数学公式。但这与delta_accumulators的同一行:优化器状态将定义训练期间如何更新模型的权重。它们的形状也可能与模型的权重相同,但要两倍。