我正在Keras尝试LSTM,几乎没有运气。在某些时刻,我决定缩回最基本的问题,以便最终取得一些积极成果 然而,即使是最简单的问题,我发现Keras无法收敛,而Tensorflow中相同问题的实现给出了稳定的结果。
我不愿意在没有理解为什么Keras在我尝试的任何问题上不断分歧的情况下切换到Tensorflow。
我的问题是延迟回声的多对多序列预测,例如:
蓝线是网络输入序列,红色虚线是预期输出
实验的灵感来自于这个repo,并且也可以从中创建可行的Tensorflow解决方案。
我的代码的相关摘录如下,我的最小可重复示例的完整版可用here。
Keras模特:
model = Sequential()
model.add(LSTM(n_hidden,
input_shape=(n_steps, n_input),
return_sequences=True))
model.add(TimeDistributed(Dense(n_input, activation='linear')))
model.compile(loss=custom_loss,
optimizer=keras.optimizers.Adam(lr=learning_rate),
metrics=[])
Tensorflow模型:
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_steps])
weights = {
'out': tf.Variable(tf.random_normal([n_hidden, n_steps], seed = SEED))
}
biases = {
'out': tf.Variable(tf.random_normal([n_steps], seed = SEED))
}
lstm = rnn.LSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.nn.dynamic_rnn(lstm, inputs=x,
dtype=tf.float32,
time_major=False)
h = tf.transpose(outputs, [1, 0, 2])
pred = tf.nn.bias_add(tf.matmul(h[-1], weights['out']), biases['out'])
individual_losses = tf.reduce_sum(tf.squared_difference(pred, y),
reduction_indices=1)
loss = tf.reduce_mean(individual_losses)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) \
.minimize(loss)
我声称代码的其他部分(data_generation
,training
)完全相同。但是,与Keras的学习进展很早就会停滞不前并产生令人不满意的预测。图书馆和示例预测的logloss
图表如下:
Keras训练模型的Logloss:
从图表中读取并不容易,但Tensorflow达到
target_loss=0.15
并在大约10k批次之后提前停止。但是Keras只用了loss
只有1.5
的所有13k批次。在Keras以10万批次运行的单独实验中,1.0
周围没有进一步停滞。
下图包含:黑线 - 模型输入信号,绿色虚线 - 地面实况输出,红线 - 采集模型输出。
答案 0 :(得分:0)
好的,我设法解决了这个问题。 Keras的实施现在也稳步收敛到一个合理的解决方案:
这些模型实际上并不完全相同。您可以格外检查问题中的Tensorflow
模型版本,并自行验证下面列出了实际的Keras
等效项,并且不是问题中所述的内容:
model = Sequential()
model.add(LSTM(n_hidden,
input_shape=(n_steps, n_input),
return_sequences=False))
model.add(Dense(n_steps, input_shape=(n_hidden,), activation='linear'))
model.compile(loss=custom_loss,
optimizer=keras.optimizers.Adam(lr=learning_rate),
metrics=[])
我会详细说明。这里可行的解决方案使用LSTM吐出的最后一列大小为n_hidden
的列作为中间激活,然后输入Dense
层。
因此,在某种程度上,这里的实际预测是由常规感知器进行的。
一个额外的删除注释 - 原始Keras
解决方案中的错误来源已经从问题附带的推理示例中显而易见。我们看到那里的早期时间戳完全失败,而后来的时间戳接近完美。这些早期的时间戳对应于LSTM刚刚在新窗口初始化并且无法上下文时的状态。