从2个自动编码器中提取特征并将其输入MLP

时间:2018-06-02 15:48:39

标签: python neural-network keras autoencoder

据我所知,从自动编码器中提取的特征可以输入到mlp中进行分类或回归。这是我之前做的事情 但是,如果我有2个自动编码器怎么办?我可以从2个自动编码器的瓶颈层中提取特征并将它们输入到基于这些特征进行分类的mlp中吗?如果是,那怎么样?我不知道如何连接这两个功能集。我尝试使用numpy.hstack(),它给了我不可用的切片'错误,然而,使用tf.concat()给我的错误'模型的输入张量必须是Keras张量。'两个自动编码器的瓶颈层各有一个尺寸(无,100)。所以,基本上,如果我将它们水平堆叠,我应该得到一个(无,200)。 mlp的隐藏层可能包含一些(num_hidden = 100)神经元。有人可以帮忙吗?

x1 = autoencoder1.get_layer('encoder2').output
x2 = autoencoder2.get_layer('encoder2').output

#inp = np.hstack((x1, x2))
inp = tf.concat([x1, x2], 1)
x = tf.concat([x1, x2], 1)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
y = Dense(1, activation='sigmoid', name='prediction')(h)
mymlp = Model(inputs=inp, outputs=y)

# Compile model
mymlp.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train model
mymlp.fit(x_train, y_train, epochs=20, batch_size=8)

根据@ twolffpiggott的建议更新:

from keras.layers import Input, Dense, Dropout
from keras import layers
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import numpy as np

x1 = Data1
x2 = Data2
y = Data3

num_neurons1 = x1.shape[1]
num_neurons2 = x2.shape[1]

# Train-test split
x1_train, x1_test, x2_train, x2_test, y_train, y_test = train_test_split(x1, x2, y, test_size=0.2)

# scale data within [0-1] range
scalar = MinMaxScaler()
x1_train = scalar.fit_transform(x1_train)
x1_test = scalar.transform(x1_test)

x2_train = scalar.fit_transform(x2_train)
x2_test = scalar.transform(x2_test)

x_train = np.concatenate([x1_train, x2_train], axis =-1)
x_test = np.concatenate([x1_test, x2_test], axis =-1)

# Auto-encoder1

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons1,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded1 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded1)
decoded = Dense(num_neurons1, activation='sigmoid', name='decoder2')(decoded)

# this model maps an input to its reconstruction
autoencoder1 = Model(inputs=input_data, outputs=decoded)
autoencoder1.compile(optimizer='sgd', loss='mse')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    

# training
autoencoder1.fit(x1_train, x1_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x1_test, x1_test))

# Auto-encoder2

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons2,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded2 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded2)
decoded = Dense(num_neurons2, activation='sigmoid', name='decoder2')(decoded)


# this model maps an input to its reconstruction
autoencoder2 = Model(inputs=input_data, outputs=decoded)
autoencoder2.compile(optimizer='sgd', loss='mse')

# training
autoencoder2.fit(x2_train, x2_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x2_test, x2_test))

# MLP

num_hidden = 100

encoded1.trainable = False
encoded2.trainable = False

encoded1 = autoencoder1(autoencoder1.inputs)
encoded2 = autoencoder2(autoencoder2.inputs)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
x = Dropout(0.2)(concatenated)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

# Compile model
myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit(x_train, y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict(x_test)

给我一​​个错误:不可用的类型:' list'从行: myMLP = Model(inputs = [autoencoder1.inputs,autoencoder2.inputs],outputs = y)

2 个答案:

答案 0 :(得分:2)

问题在于你将numpy数组与keras张量混合在一起。这不可能。

有两种方法。

  • 预测每个自动编码器中的numpy数组,连接数组,将它们发送到第三个模型
  • 连接所有型号,可能使自动编码器无法控制,适合每个自动编码器的一个输入。

就个人而言,我是第一个去的。 (假设自动编码器已经过培训,不需要更改)。

第一种方法

numpyOutputFromAuto1 = autoencoder1.predict(numpyInputs1)    
numpyOutputFromAuto2 = autoencoder2.predict(numpyInputs2)

inputDataForThird = np.concatenate([numpyOutputFromAuto1,numpyOutputFromAuto2],axis=-1)

inputTensorForMlp = Input(inputsForThird.shape[1:])
h = Dense(num_hidden, activation='relu', name='hidden')(inputTensorForMlp)
y = Dense(1, activation='sigmoid', name='prediction')(h)

mymlp = Model(inputs=inputTensorForMlp, outputs=y)

....
mymlp.fit(inputDataForThird ,someY)

第二种方法

这有点复杂,起初我没有太多理由去做这件事。 (但当然可能存在一个不错的选择)

现在我们完全忘记了numpy并且与keras张量合作。

单独创建mlp(如果稍后在没有自动编码器的情况下使用它,那就很好):

inputTensorForMlp = Input(input_shape_compatible_with_concatenated_encoder_outputs)
x = Dropout(0.2)(inputTensorForMlp)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

我们可能想要自动编码器的瓶颈功能,对吧?如果您碰巧使用以下方法正确创建自动编码器:编码器模型,解码器模型,加入两者,那么它更容易使用编码器模型。否则:

encodedOutput1 = autoencoder1.layers[bottleneckLayer].outputs #or encoder1.outputs
encodedOutput2 = autoencoder1.layers[bottleneckLayer].outputs #or encoder2.outputs

创建已加入的模型。连接必须使用keras层(我们正在使用keras张量):

concatenated = Concatenate()([encodedOutput1,encodedOutput2])
output = myMLP(concatenated)

joinedModel = Model([autoencoder1.input,autoencoder2.input],output)

答案 1 :(得分:0)

我也会选择丹尼尔的第一种方法(为了简单和效率),但如果你对第二种方法感兴趣的话。例如,如果您对端到端运行网络感兴趣,可以这样做:

input_data1

主要修改

  1. 两个自动编码器的权重应该是端到端冻结的(否则,随机初始化的MLP的早期阶梯更新可能会导致大部分学习失败)。
  2. 应将自动编码器输入图层分配给每个自动编码器(而不是input_data2两者)的变量input_dataautoencoder1.inputs。即使unhashable type: list返回tf张量,这也是[input_data1, input_data2]例外的来源,替换为x1_train可以解决问题。
  3. 在为端到端模型拟合MLP时,输入应该是x2_train和{{1}}的列表,而不是连接的输入。预测时也一样。
相关问题