我正在尝试创建Keras LSTM(请注意,我是Keras中的LSTM和RNN的新手)。该神经网络应该接受4116个值的输入,并输出4116个值。这将执行288个时间步。我有27个这样的时间步长(我意识到这可能会导致过拟合;我有一个更大的数据集,但首先想仅使用27个训练示例来测试我的代码)。
训练数据存储在两个numpy数组x
和y
中。这些变量的形状为(27, 288, 4116)
。
我的代码:
datapoints = data.get.monthPoints(2, year)
x, y = datapoints[:-1], datapoints[1:]
del datapoints
input_shape = x.shape[1:]
output_shape = y.shape[1:]
checkpoint = ModelCheckpoint('model/files/alpha.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='auto', period=1)
early = EarlyStopping(monitor='val_loss', min_delta=0, patience=1, verbose=1, mode='auto')
model = Sequential()
model.add(LSTM(5488, activation='relu', input_shape=input_shape))
model.add(RepeatVector(output_shape))
model.add(LSTM(5488, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(output_shape)))
model.compile(loss='mse', optimizer='adam')
model.fit(x, y, epochs=100, batch_size=8, callbacks = [checkpoint, early])
运行程序时,出现以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'repeat_vector/stack_1' (op: 'Pack') with input shapes: [], [2], []
和
During handling of the above exception, another exception occurred:
ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'repeat_vector/stack_1' (op: 'Pack') with input shapes: [], [2], []
我还看到了其他一些类似的问题,例如this和this,但是他们没有提供解决我的问题的解决方案,或者解决方案不清楚。
我猜我的问题与我错误地构建网络或错误格式化数据有关。
任何见识都会令我感激。
谢谢。
答案 0 :(得分:1)
您可能希望重复第一LSTM层的输出,其重复次数与模型输出序列中的时间步长(即y
)相同。因此,应该是:
model.add(RepeatVector(y.shape[1]))
答案 1 :(得分:1)
您的代码中有两个问题。首先,在RepeatVector
中,您通过传递y.shape [1:]发送一个列表。在RepeatVector
中,您应该发送一个整数。
第二个问题在TimeDistributed
中。发送您希望重复第二维的次数。
因此您的代码应为:
repeat_vector_units = x.shape[1]
output_units = x.shape[2]
model = Sequential()
model.add(LSTM(5488, activation='relu', input_shape=input_shape))
model.add(RepeatVector(repeat_vector_units)) ### Change here
model.add(LSTM(5488, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(output_units))) #### Change here