如何正确设置keras密集输出层的大小

时间:2019-02-16 22:52:53

标签: python keras lstm

我正在尝试使用来自midi文件的数据在keras中运行RNN来生成音乐,但我在模型的最后一层遇到了一个问题,该错误出现在错误中:“检查目标时出错:预期的density_26为形状为(1,),但数组的形状为(110539,)”。当标签的大小为(110539)时,为什么网络希望输出层大小为1,密集层应为多少?我的代码如下:

input_data=[np.random.randint(0,110539) for i in range(32427)]
input_temp =[]
output_temp = []
for i in range(0,len(input_data)-seq_length,1):
    input_temp.append(input_data[i:i+seq_length])
    output_temp.append(input_data[i+seq_length])

sequences = len(output_temp)

x = np.reshape(input_temp,(sequences,seq_length,1))
x = x/classes
y = keras.utils.to_categorical(output_temp)
classes = y[0].shape[0]

model = Sequential()
model.add(CuDNNLSTM(512,input_shape=(seq_length,1),return_sequences=True))
model.add(Dropout(0.2))
model.add(CuDNNLSTM(512))
model.add(Dropout(0.2))
model.add(Dense(256,activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(classes,activation='softmax'))
opt = keras.optimizers.Adam(lr=1e-3,decay=1e-5)
model.compile(loss='sparse_categorical_crossentropy',optimizer=opt)
model.fit(x,y,epochs=3)

当我打印x.shape时,我得到(32327,100,1)并且y的尺寸是(32327,类)。感谢您的帮助

edit:model.summary()的输出

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cu_dnnlstm_1 (CuDNNLSTM)     (None, 100, 512)          1054720   
_________________________________________________________________
dropout_1 (Dropout)          (None, 100, 512)          0         
_________________________________________________________________
cu_dnnlstm_2 (CuDNNLSTM)     (None, 512)               2101248   
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328    
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 110539)            28408523  
=================================================================

0 个答案:

没有答案