训练LSTM自动编码器会导致NaN / MSE超高损失

时间:2020-03-06 08:04:53

标签: python tensorflow keras deep-learning lstm

我正在尝试训练LSTM ae。 就像seq2seq模型一样,您输入一个信号以获得重构的信号序列。我正在使用一个序列,应该很容易。损失函数和指标是MSE。前一百个时期进展顺利。但是,经过一段时间后,我得到了MSE很高的MSE,有时甚至达到了NaN。我不知道是什么原因造成的。 您可以检查代码并给我提示吗? 该序列之前已进行归一化,因此处于[0,1]范围内,如何产生如此高的MSE错误? 这是我从训练集中得到的输入序列:

sequence1 = x_train[0][:128]

看起来像这样:

我从公共信号数据集中获取数据(128 * 1) 这是代码:(我是从keras博客修改的)

# lstm autoencoder recreate sequence
from numpy import array
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import RepeatVector
from keras.layers import TimeDistributed
from keras.utils import plot_model
from keras import regularizers

# define input sequence. sequence1 is only a one dimensional list
# reshape sequence1 input into [samples, timesteps, features]
n_in = len(sequence1)
sequence = sequence1.reshape((1, n_in, 1))
# define model
model = Sequential()
model.add(LSTM(1024, activation='relu', input_shape=(n_in,1)))
model.add(RepeatVector(n_in))
model.add(LSTM(1024, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(1)))
model.compile(optimizer='adam', loss='mse')
for epo in [50,100,1000,2000]:
   model.fit(sequence, sequence, epochs=epo)

前几个时期进展顺利。所有的损失约为0.003倍左右。然后突然变大了,到了很大数量的NaN一直都在上升。

2 个答案:

答案 0 :(得分:1)

进行反向传播时,可能会出现梯度值爆炸的问题。 尝试使用clipnorm和clipvalue参数来控制渐变裁剪:https://keras.io/optimizers/

或者,您使用的学习率是多少?我还将尝试将学习率降低10,100,1000,以检查您是否观察到相同的行为。

答案 1 :(得分:1)

'relu'是罪魁祸首-请参阅here。可能的解决方案:

  1. 将权重初始化为较小的值,例如keras.initializers.TruncatedNormal(mean=0.0, stddev=0.01)
  2. 剪辑权重(在初始化时,或通过kernel_constraintrecurrent_constraint等)
  3. 增加体重衰减
  4. 使用预热学习率方案(从低开始,逐渐增加)
  5. 使用'selu'激活更加稳定,类似于ReLU,并且在某些任务上比ReLU更好

由于您的训练在许多时期都保持稳定,因此 3 听起来是最有前途的,因为最终您的权重范数会变得太大,并且梯度会爆炸。通常,我建议将'relu'的权重标准保持在1附近;您可以使用以下功能监视l2规范。我还建议使用See RNN检查图层的激活和渐变。


def inspect_weights_l2(model, names='lstm', axis=-1):
    def _get_l2(w, axis=-1):
        axis = axis if axis != -1 else len(w.shape) - 1
        reduction_axes = tuple([ax for ax in range(len(w.shape)) if ax != axis])
        return np.sqrt(np.sum(np.square(w), axis=reduction_axes))

    def _print_layer_l2(layer, idx, axis=-1):
        W = layer.get_weights()
        l2_all = []
        txt = "{} "

        for w in W:
            txt += "{:.4f}, {:.4f} -- "
            l2 = _get_l2(w, axis)
            l2_all.extend([l2.max(), l2.mean()])
        txt = txt.rstrip(" -- ")

        print(txt.format(idx, *l2_all))

    names = [names] if isinstance(names, str) else names

    for idx, layer in enumerate(model.layers):
        if any([name in layer.name.lower() for name in names]):
            _print_layer_l2(layer, idx, axis=axis)