将fit_generator与多个输入配合使用会在输出密集层产生错误

时间:2019-05-29 12:16:35

标签: python keras deep-learning

就我而言,我正在使用一组顺序特征以及非顺序特征来训练模型。以下是我的模型的架构

Sequential features -> LSTM -> Dense(1) --->>
                                             \
                                              \
                                               -- Dense -> Dense -> Dense(1) ->output
                                              /
                   Non-sequential features---/

我正在使用数据生成器来为连续数据生成批处理。在此,每个批次的批次大小都不同。对于一个批次,我将固定非顺序功能。以下是我的数据生成器。

def training_data_generator(raw_data):
    while True:
        for index, row in raw_data.iterrows():
            x_train, y_train = list(), list()
            feature1 = row['xxx']
            x_current_batch = []
            y_current_batch = []
            for j in range(yyy):
                x_current_batch.append(row['zz1'])
                y_current_batch.append(row['zz2'])
            x_train.append(x_current_batch)
            y_train.append(y_current_batch)
            x_train = array(x_train)
            y_train = array(y_train)

            yield [x_train, np.reshape(feature1,1)], y_train

注意:x_train和y_train的大小各不相同。

以下是我的模型实现。

seq_input = Input(shape=(None, 3))
lstm_layer = LSTM(50)(seq_input)
dense_layer1 = Dense(1)(lstm_layer)

non_seq_input = Input(shape=(1,))

hybrid_model = concatenate([dense_layer1, non_seq_input])

hidden1 = Dense(10, activation = 'relu')(hybrid_model)
hidden2 = Dense(10, activation='relu')(hidden1)

final_output = Dense(1, activation='sigmoid')(hidden2)

model = Model(inputs = [seq_input, non_seq_input], outputs = final_output)

model.compile(loss='mse',optimizer='adam')

model.fit_generator(training_data_generator(flatten), steps_per_epoch= 5017,
                              epochs = const.NUMBER_OF_EPOCHS, verbose=1)

我在输出密集层出现错误

ValueError: Error when checking target: 
expected dense_4 to have shape (1,) but got array with shape (4,)

我认为最后一层正在获取生成器的全部输出,而不是一个接一个地输出。

此问题的原因是什么。赞赏您对这个问题的见解。

1 个答案:

答案 0 :(得分:0)

输出给出一个大小为4的Dense层。由​​于您已将输出声明为大小为1的Dense层,因此会崩溃。

您可以做的是将输出密集图层更改为4。然后手动将其转换为一个值。

希望这能回答您的问题。