I am trying to build a multivariate (100, 20, 23), multi-output [(100, 20, 1), (100, 20, 1), (100, 20, 1)] sequence-to-sequence encoder/decoder. There are not many examples of this online, but I found this tutorial on an encoder-decoder LSTM model with a multivariate input sequence and a single output sequence, and I want to generalize it to multiple outputs. I tried the code below and got a ValueError relating to the different TimeDistributed layers; see the traceback further down. When I instead concatenate the outputs into a single (100, 20, 3) target, the model compiles and fits correctly.
Any ideas how to fix this so I can supply the output data separately?
import numpy as np
from keras.layers import *
from keras import Model
n_features = 23
hl2 = 150
window_len = 20
epochs = 5
batch_size = 128
features = np.random.rand(100, 20, 23)
targ1 = np.random.rand(100, 20, 1)
targ2 = np.random.rand(100, 20, 1)
targ3 = np.random.rand(100, 20, 1)
input_ = Input(shape=(window_len, n_features))
encoder_LSTM = LSTM(hl2, kernel_initializer='glorot_normal')(input_)
rep_vec = RepeatVector(window_len)(encoder_LSTM)
decoder_LSTM = LSTM(hl2, return_sequences=True)(rep_vec)
time_wrapper = TimeDistributed(Dense(int(hl2/2), activation='relu'))(decoder_LSTM)
out1 = TimeDistributed(
    Dense(1, activation='linear', name='out1',
          kernel_initializer='glorot_normal')
)(time_wrapper)
out2 = TimeDistributed(
    Dense(1, activation='linear', name='out2',
          kernel_initializer='glorot_normal')
)(time_wrapper)
out3 = TimeDistributed(
    Dense(1, activation='linear', name='out3',
          kernel_initializer='glorot_normal')
)(time_wrapper)
model = Model(inputs=input_, outputs=[out1, out2, out3])
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(
    features,
    {
        'out1': targ1,
        'out2': targ2,
        'out3': targ3,
    },
    epochs=epochs,
    batch_size=batch_size,
    verbose=0,
    shuffle='batch'
)
Here is the traceback:
Traceback (most recent call last):
  File "C:/Users/Master Tk/PycharmProjects/FPL/encoder.py", line 49, in <module>
    shuffle='batch'
  File "C:\Miniconda3\lib\site-packages\keras\engine\training.py", line 952, in fit
    batch_size=batch_size)
  File "C:\Miniconda3\lib\site-packages\keras\engine\training.py", line 789, in _standardize_user_data
    exception_prefix='target')
  File "C:\Miniconda3\lib\site-packages\keras\engine\training_utils.py", line 78, in standardize_input_data
    'for each key in: ' + str(names))
ValueError: No data provided for "time_distributed_2". Need data for each key in: ['time_distributed_2', 'time_distributed_3', 'time_distributed_4']
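My guess at what is going wrong: the `name='out1'` arguments land on the inner `Dense` layers, so the model's output layers are the anonymous `TimeDistributed` wrappers (`time_distributed_2` etc.), and the keys in my target dict don't match them. The sketch below is what I think the fix might look like, with the names moved onto the `TimeDistributed` wrappers themselves (I'm using `tensorflow.keras` here and haven't confirmed this is the idiomatic approach):

```python
import numpy as np
from tensorflow.keras.layers import (Input, LSTM, RepeatVector,
                                     TimeDistributed, Dense)
from tensorflow.keras.models import Model

n_features, hl2, window_len = 23, 150, 20

inp = Input(shape=(window_len, n_features))
enc = LSTM(hl2, kernel_initializer='glorot_normal')(inp)
rep = RepeatVector(window_len)(enc)
dec = LSTM(hl2, return_sequences=True)(rep)
shared = TimeDistributed(Dense(hl2 // 2, activation='relu'))(dec)

# Name the TimeDistributed wrappers (the actual output layers),
# not the Dense layers nested inside them, so that the keys of
# the target dict passed to fit() can resolve against them.
out1 = TimeDistributed(Dense(1, activation='linear'), name='out1')(shared)
out2 = TimeDistributed(Dense(1, activation='linear'), name='out2')(shared)
out3 = TimeDistributed(Dense(1, activation='linear'), name='out3')(shared)

model = Model(inputs=inp, outputs=[out1, out2, out3])
model.compile(loss='mean_squared_error', optimizer='adam')

features = np.random.rand(100, window_len, n_features)
targ1 = np.random.rand(100, window_len, 1)
targ2 = np.random.rand(100, window_len, 1)
targ3 = np.random.rand(100, window_len, 1)

# With matching names, per-output target dicts are accepted.
model.fit(features,
          {'out1': targ1, 'out2': targ2, 'out3': targ3},
          epochs=1, batch_size=128, verbose=0)
```

If this is right, the same trick should also let me weight the three losses individually via `loss_weights` in `compile()`.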