火车正确堆叠自动编码器-Keras

时间:2018-09-07 10:58:02

标签: python tensorflow machine-learning keras deep-learning

我尝试在Keras(tf.keras)中构建一个堆栈式自动编码器。 堆积不是指。我为Keras找到的所有示例都在生成例如3个编码器层,3个解码器层,他们对其进行训练,并将其称为一天。但是,本文介绍的一种训练堆叠自动编码器(SAE)的正确方法似乎是:Stacked Denoising Autoencoders: Learning Useful Representations in a Deep Network with a Local Denoising Criterion

简而言之,应该对SAE进行分层训练,如下图所示。训练完第1层后,将其用作训练第2层的输入。应将重建损失与第1层而不是输入层进行比较。

这就是我的麻烦开始的地方。如何告诉Keras在哪些层上使用损失函数?

这是我的工作。由于Keras中不再存在自动编码器模块,因此我构建了第一个自动编码器,并在第二个自动编码器的第一层(共2层)中设置了其编码器的权重(trainable = False)。然后,当我训练它时,它显然会将重构的out_s2层与输入层in_s而不是第1层hid1进行比较。

# autoencoder layer 1
in_s = tf.keras.Input(shape=(input_size,))
noise = tf.keras.layers.Dropout(0.1)(in_s)
hid = tf.keras.layers.Dense(nodes[0], activation='relu')(noise)
out_s = tf.keras.layers.Dense(input_size, activation='sigmoid')(hid)

ae_1 = tf.keras.Model(in_s, out_s, name="ae_1")
ae_1.compile(optimizer='nadam', loss='binary_crossentropy', metrics=['acc'])

# autoencoder layer 2
hid1 = tf.keras.layers.Dense(nodes[0], activation='relu')(in_s)
noise = tf.keras.layers.Dropout(0.1)(hid1)
hid2 = tf.keras.layers.Dense(nodes[1], activation='relu')(noise)
out_s2 = tf.keras.layers.Dense(nodes[0], activation='sigmoid')(hid2)

ae_2 = tf.keras.Model(in_s, out_s2, name="ae_2")
ae_2.layers[0].set_weights(ae_1.layers[0].get_weights())
ae_2.layers[0].trainable = False

ae_2.compile(optimizer='nadam', loss='binary_crossentropy', metrics=['acc'])

该解决方案应该相当简单,但是我看不到它,也无法在线找到它。我该如何在Keras中做到这一点?

1 个答案:

答案 0 :(得分:0)

通过查看评论,这个问题似乎已经过时了。但我仍然会回答这个问题,因为这个问题中提到的用例不仅特定于自动编码器,而且可能对其他一些情况有所帮助。

所以,当你说“一层一层地训练整个网络”时,我宁愿将其解释为“在一个序列中训练一个单层的小型网络”。

看这个问题贴出的代码,好像OP已经搭建了小型网络。 但这两个网络都不是由一层组成。

这里的第二个自动编码器将第一个自动编码器的输入作为输入。但是,它实际上应该将第一个自动编码器的输出作为输入。

那么,你训练第一个自动编码器并在训练后收集它的预测。然后训练第二个自动编码器,它将第一个自动编码器的输出(预测)作为输入。

现在让我们关注这一部分:“第 1 层训练完成后,用作训练第 2 层的输入。重建损失应该与第 1 层而不是输入层进行比较。”

由于网络将第 1 层的输出(在 OP 的情况下为自动编码器 1)的输出作为输入,它将与它的输出进行比较。任务完成。

但要实现这一点,您需要编写问题中提供的代码中缺少的 model.fit(...) 行。

此外,如果您希望模型计算输入层的损失,您只需将 y 中的 model,fit(...) 参数替换为自动编码器 1 的输入。

简而言之,您只需要将这些自编码器解耦为一个单层的微型网络,然后根据需要训练它们。现在无需使用 trainable = False,否则您可以随意使用它。