Incorrect LSTM time-series predictions when the input size differs from the training input size

Date: 2020-01-05 16:11:16

Tags: python keras time-series lstm recurrent-neural-network

I am working with the Bach chorales dataset. Each chorale is roughly 100-500 time steps long, and each time step contains 4 integers (e.g. [74, 70, 65, 58]), where each integer corresponds to a note index on the piano.

I am trying to train a model that can predict the next time step (4 notes), working through each chorale in chronological order from the beginning.

The problem: I get correct output for inputs of the same size the model was trained on, but wrong output for inputs of a different size.

What I have done so far: I used Keras's TimeseriesGenerator to generate sequences of inputs and corresponding outputs:

generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
print(generator[0])

Output:

(array([[[74, 70, 65, 58],
        [74, 70, 65, 58],
        [74, 70, 65, 58]]]), array([[75, 70, 58, 55]]))

Then I trained the LSTM model. I used None in input_shape to allow variable-size inputs.

n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128, activation='relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')

# fit model
model.fit_generator(generator, epochs=500, validation_data=validation_generator)
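
As a quick sanity check (a sketch, not part of the original post): with input_shape=(None, 4) the network accepts sequences of any length, so the failure shown below is about what the network learned, not about shapes:

import numpy as np

# Both the trained length (3) and unseen lengths pass through fine shape-wise.
for T in (3, 5, 10):
    x = np.random.randint(50, 80, size=(1, T, 4))
    print(T, model.predict(x, verbose=0).shape)   # -> (1, 4) for every T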

Predicting on an input of size 3 seems to work (since the model was trained on inputs of length 3):

# demonstrate prediction
x_input = dataX[5:8]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[8])

Output:

[[[75 70 58 55]
  [75 70 60 55]
  [75 70 60 55]]]
[[76.25768  68.525444 59.745518 53.799873]]
expected:  [77 69 62 50]

Now I try predicting on an input of a different size, length 5, and this does not work. Output for the test sample:

# demonstrate prediction
x_input = dataX[1:6]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[6])

Output:

[[[74 70 65 58]
  [74 70 65 58]
  [74 70 65 58]
  [75 70 58 55]
  [75 70 58 55]]]
[[227.16667 217.89767 213.62988 148.44817]]
expected:  [75 70 60 55]

The prediction is completely wrong; it looks as though it is doing some kind of summation. Any input/help on why this happens and how to fix it would be greatly appreciated.
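
One way to check whether the outputs really scale with sequence length is to repeat a single frame for increasing lengths and watch the prediction magnitude (a diagnostic sketch, not part of the original question; it reuses model and dataX from above):

import numpy as np

# If the relu-activated recurrent states accumulate, the prediction
# magnitude should grow roughly with the number of time steps T.
frame = dataX[5]                          # one 4-note time step
for T in (3, 5, 10, 20):
    x = np.tile(frame, (T, 1)).reshape(1, T, 4)
    print(T, model.predict(x, verbose=0))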

2 answers:

Answer 0 (score: 1)

I can offer three possible reasons why your model fails to learn.

The final Dense layer

model.add(Dense(n_features))

This is probably the main culprit in your model (but I suggest addressing all three points). The final layer of a classification model needs to be a softmax layer, so simply change it to

model.add(Dense(n_features, activation='softmax'))

Loss function

Generally, crossentropy works better than mse for classification problems. So try,

model.compile(optimizer='adam', loss='categorical_crossentropy')
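
Taken together with the softmax change above, the classification framing would look roughly like this (a sketch assuming a 128-note MIDI-style vocabulary and predicting a single voice; n_notes, clf, and the one-voice restriction are illustrative, not from the answer):

from keras.models import Sequential
from keras.layers import LSTM, Dense
from keras.utils import to_categorical

n_notes = 128  # assumed vocabulary size (MIDI pitch range)

clf = Sequential()
clf.add(LSTM(100, input_shape=(None, 4)))
clf.add(Dense(n_notes, activation='softmax'))
clf.compile(optimizer='adam', loss='categorical_crossentropy')

# Targets become one-hot note classes instead of raw integers
# (only the first of the four voices here, for simplicity).
y_onehot = to_categorical(dataY[:, 0], num_classes=n_notes)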

Activation in the LSTM

LSTMs use tanh as their activation. Unless you have a good reason, do not change it to relu: an LSTM does not behave like a regular feed-forward layer when its activation function is swapped out.
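
Concretely, that means dropping the activation argument so both layers fall back to the LSTM default, tanh (a sketch of the question's model with only that change):

from keras.models import Sequential
from keras.layers import LSTM, Dense

n_features = 4
model = Sequential()
model.add(LSTM(100, input_shape=(None, n_features), return_sequences=True))  # tanh by default
model.add(LSTM(128))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')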

Answer 1 (score: 0)

I suggest it is better to keep the length of x_input at 3. Below is my test code:

import sys
from keras.models import Sequential
from keras.layers import Dense,Activation,LSTM
from keras.preprocessing.sequence import TimeseriesGenerator
import numpy as np
import logger  # the answerer's own logging helper, not a standard package
logger.logger_initialize('LOGGER.log')


def bc_pitches():
    # Build a (chorale, time) matrix of pitches from the Lisp-formatted dataset.
    a = open('chorales.lisp', 'r')

    #parse the input as vectors and store vectors

    def obtainNum(elemSt):
        a = elemSt.split(" ")
        return int(a[1])

    bookOfLists = []

    for i in range(210):
        counter = 0
        gun = a.readline()
        if (len(gun) <= 1): #for /n accommodation
            continue
        else:
            while (gun[counter:(counter+2)] != "(("):
                counter += 1
            tribo = gun[(counter+2):(len(gun)-4)]
            stringArr = tribo.split("))((") #separates each vector into an element
            lister = [x.split(") (") for x in stringArr]
            #lister = map(lambda x : x.split(") ("), stringArr) #each vector becomes
            #a list of component elements so lister is a list of lists
            lister2 = [[obtainNum(each) for each in x] for x in lister]
            #lister2 = map(lambda x : map(obtainNum, x), lister)
            bookOfLists.append(lister2)
    pitches=np.zeros([100,500],dtype=np.int32)
    for i in range(len(bookOfLists)):
        for j in range(len(bookOfLists[i])):
            for t in range(bookOfLists[i][j][0],bookOfLists[i][j][0]+bookOfLists[i][j][2]):
                try:
                    pitches[i][t]=bookOfLists[i][j][1]
                except:
                    print(i,j,t)
                    sys.exit()
    return pitches

pitches = bc_pitches()
# Transpose the first 4 pitch rows into (time, 4) features; first 150 steps for training.
dataX = dataY = (pitches[:4, :].T)[:150]
generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
for i in range(len(generator)):
    logger.info(i,generator[i])

validation_dataX = validation_dataY = (pitches[:4, :].T)[150:]  # remaining steps for validation
validation_generator = TimeseriesGenerator(validation_dataX, validation_dataY, length=3, batch_size=1)


n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128, activation='relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')

# fit model
model.fit_generator(generator, epochs=50, validation_data=validation_generator)


# demonstrate prediction
x_input = (pitches[:4,:].T)[155:158]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4,:].T)[158])


# demonstrate prediction
x_input = (pitches[:4,:].T)[151:156]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4,:].T)[156])

for i in range(10):
    yhat = model.predict(validation_generator[i][0], verbose=0)
    logger.info(i,yhat)
    logger.info('expected: ', validation_generator[i][1])

And the results:

...
    100 (array([[[72, 73, 69, 73],
            [72, 73, 69, 73],
            [72, 73, 69, 73]]]),
     array([[72, 73, 69, 73]])) 
    101 (array([[[72, 73, 69, 73],
            [72, 73, 69, 73],
            [72, 73, 69, 73]]]),
     array([[74, 71, 71, 71]])) 
    102 (array([[[72, 73, 69, 73],
            [72, 73, 69, 73],
            [74, 71, 71, 71]]]),
     array([[74, 71, 71, 71]])) 
    103 (array([[[72, 73, 69, 73],
            [74, 71, 71, 71],
            [74, 71, 71, 71]]]),
     array([[74, 71, 71, 71]])) 
    104 (array([[[74, 71, 71, 71],
            [74, 71, 71, 71],
            [74, 71, 71, 71]]]),
     array([[74, 71, 71, 71]])) 
    105 (array([[[74, 71, 71, 71],
            [74, 71, 71, 71],
            [74, 71, 71, 71]]]),
     array([[74, 73, 67, 71]])) 
    106 (array([[[74, 71, 71, 71],
            [74, 71, 71, 71],
            [74, 73, 67, 71]]]),
     array([[74, 73, 67, 71]])) 
    107 (array([[[74, 71, 71, 71],
            [74, 73, 67, 71],
            [74, 73, 67, 71]]]),
     array([[74, 73, 67, 71]])) 
    108 (array([[[74, 73, 67, 71],
            [74, 73, 67, 71],
            [74, 73, 67, 71]]]),
     array([[74, 73, 67, 71]])) 
    109 (array([[[74, 73, 67, 71],
            [74, 73, 67, 71],
            [74, 73, 67, 71]]]),
     array([[74, 74, 69, 76]])) 
    110 (array([[[74, 73, 67, 71],
            [74, 73, 67, 71],
            [74, 74, 69, 76]]]),
     array([[74, 74, 69, 76]])) 
    111 (array([[[74, 73, 67, 71],
            [74, 74, 69, 76],
            [74, 74, 69, 76]]]),
     array([[72, 74, 71, 76]])) 
    112 (array([[[74, 74, 69, 76],
            [74, 74, 69, 76],
            [72, 74, 71, 76]]]),
     array([[72, 74, 71, 76]])) 
    113 (array([[[74, 74, 69, 76],
            [72, 74, 71, 76],
            [72, 74, 71, 76]]]),
     array([[71, 73, 72, 71]])) 
    114 (array([[[72, 74, 71, 76],
            [72, 74, 71, 76],
            [71, 73, 72, 71]]]),
     array([[71, 73, 72, 71]])) 
    115 (array([[[72, 74, 71, 76],
            [71, 73, 72, 71],
            [71, 73, 72, 71]]]),
     array([[71, 73, 72, 71]])) 
    116 (array([[[71, 73, 72, 71],
            [71, 73, 72, 71],
            [71, 73, 72, 71]]]),
     array([[71, 73, 72, 71]])) 
    117 (array([[[71, 73, 72, 71],
            [71, 73, 72, 71],
            [71, 73, 72, 71]]]),
     array([[69, 71, 71, 73]])) 
    118 (array([[[71, 73, 72, 71],
            [71, 73, 72, 71],
            [69, 71, 71, 73]]]),
     array([[69, 71, 71, 73]])) 
    119 (array([[[71, 73, 72, 71],
            [69, 71, 71, 73],
            [69, 71, 71, 73]]]),
     array([[69, 71, 71, 73]]))
    120 (array([[[69, 71, 71, 73],
            [69, 71, 71, 73],
            [69, 71, 71, 73]]]),
     array([[69, 71, 71, 73]]))
    121 (array([[[69, 71, 71, 73],
            [69, 71, 71, 73],
            [69, 71, 71, 73]]]),
     array([[69, 70, 72, 68]]))
    122 (array([[[69, 71, 71, 73],
            [69, 71, 71, 73],
            [69, 70, 72, 68]]]),
     array([[69, 70, 72, 68]]))
    123 (array([[[69, 71, 71, 73],
            [69, 70, 72, 68],
            [69, 70, 72, 68]]]),
     array([[69, 70, 71, 69]]))
    124 (array([[[69, 70, 72, 68],
            [69, 70, 72, 68],
            [69, 70, 71, 69]]]),
     array([[69, 70, 71, 69]]))
    125 (array([[[69, 70, 72, 68],
            [69, 70, 71, 69],
            [69, 70, 71, 69]]]),
     array([[67, 71, 69, 71]]))
    126 (array([[[69, 70, 71, 69],
            [69, 70, 71, 69],
            [67, 71, 69, 71]]]),
     array([[67, 71, 69, 71]]))
    127 (array([[[69, 70, 71, 69],
            [67, 71, 69, 71],
            [67, 71, 69, 71]]]),
     array([[67, 71, 69, 71]]))
    128 (array([[[67, 71, 69, 71],
            [67, 71, 69, 71],
            [67, 71, 69, 71]]]),
     array([[67, 71, 69, 71]]))
    129 (array([[[67, 71, 69, 71],
            [67, 71, 69, 71],
            [67, 71, 69, 71]]]),
     array([[71, 71, 68, 69]]))
    130 (array([[[67, 71, 69, 71],
            [67, 71, 69, 71],
            [71, 71, 68, 69]]]),
     array([[71, 71, 68, 69]]))
    131 (array([[[67, 71, 69, 71],
            [71, 71, 68, 69],
            [71, 71, 68, 69]]]),
     array([[71, 71, 68, 69]]))
    132 (array([[[71, 71, 68, 69],
            [71, 71, 68, 69],
            [71, 71, 68, 69]]]),
     array([[71, 71, 68, 69]]))
    133 (array([[[71, 71, 68, 69],
            [71, 71, 68, 69],
            [71, 71, 68, 69]]]),
     array([[71, 71, 69, 68]]))
    134 (array([[[71, 71, 68, 69],
            [71, 71, 68, 69],
            [71, 71, 69, 68]]]),
     array([[71, 71, 69, 68]]))
    135 (array([[[71, 71, 68, 69],
            [71, 71, 69, 68],
            [71, 71, 69, 68]]]),
     array([[71, 71, 69, 68]]))
    136 (array([[[71, 71, 69, 68],
            [71, 71, 69, 68],
            [71, 71, 69, 68]]]),
     array([[71, 71, 69, 68]]))
    137 (array([[[71, 71, 69, 68],
            [71, 71, 69, 68],
            [71, 71, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    138 (array([[[71, 71, 69, 68],
            [71, 71, 69, 68],
            [72, 64, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    139 (array([[[71, 71, 69, 68],
            [72, 64, 69, 68],
            [72, 64, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    140 (array([[[72, 64, 69, 68],
            [72, 64, 69, 68],
            [72, 64, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    141 (array([[[72, 64, 69, 68],
            [72, 64, 69, 68],
            [72, 64, 69, 68]]]),
     array([[74, 69, 76, 66]]))
    142 (array([[[72, 64, 69, 68],
            [72, 64, 69, 68],
            [74, 69, 76, 66]]]),
     array([[74, 69, 76, 66]]))
    143 (array([[[72, 64, 69, 68],
            [74, 69, 76, 66],
            [74, 69, 76, 66]]]),
     array([[74, 69, 76, 66]]))
    144 (array([[[74, 69, 76, 66],
            [74, 69, 76, 66],
            [74, 69, 76, 66]]]),
     array([[74, 69, 76, 66]]))
    145 (array([[[74, 69, 76, 66],
            [74, 69, 76, 66],
            [74, 69, 76, 66]]]),
     array([[74, 71, 72, 69]]))
    146 (array([[[74, 69, 76, 66],
            [74, 69, 76, 66],
            [74, 71, 72, 69]]]),
     array([[74, 71, 72, 69]]))
    Epoch 1/50
    147/147 [==============================] - 2s 16ms/step - loss: 514.8802 - val_loss: 0.0082
    Epoch 2/50
    147/147 [==============================] - 2s 11ms/step - loss: 51.5768 - val_loss: 0.0249
    Epoch 3/50
    147/147 [==============================] - 2s 11ms/step - loss: 71.6900 - val_loss: 0.0464
    Epoch 4/50
    147/147 [==============================] - 2s 10ms/step - loss: 47.4575 - val_loss: 0.1303
    Epoch 5/50
    147/147 [==============================] - 2s 10ms/step - loss: 52.6841 - val_loss: 0.5772
    Epoch 6/50
    147/147 [==============================] - 2s 11ms/step - loss: 47.3059 - val_loss: 5.2535
    Epoch 7/50
    147/147 [==============================] - 2s 11ms/step - loss: 43.6491 - val_loss: 41.2008
    Epoch 8/50
    147/147 [==============================] - 2s 11ms/step - loss: 37.8593 - val_loss: 28.5831
    Epoch 9/50
    147/147 [==============================] - 2s 11ms/step - loss: 40.8553 - val_loss: 41.5958
    Epoch 10/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.5995 - val_loss: 57.3419
    Epoch 11/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.2054 - val_loss: 38.9516
    Epoch 12/50
    147/147 [==============================] - 2s 11ms/step - loss: 36.9247 - val_loss: 38.1881
    Epoch 13/50
    147/147 [==============================] - 2s 10ms/step - loss: 34.5922 - val_loss: 49.7601
    Epoch 14/50
    147/147 [==============================] - 2s 11ms/step - loss: 38.1668 - val_loss: 46.0043
    Epoch 15/50
    147/147 [==============================] - 2s 10ms/step - loss: 35.4724 - val_loss: 39.1485
    Epoch 16/50
    147/147 [==============================] - 2s 11ms/step - loss: 35.7787 - val_loss: 38.2263
    Epoch 17/50
    147/147 [==============================] - 2s 11ms/step - loss: 32.5241 - val_loss: 38.0783
    Epoch 18/50
    147/147 [==============================] - 2s 11ms/step - loss: 35.1693 - val_loss: 35.3403
    Epoch 19/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.5822 - val_loss: 28.0546
    Epoch 20/50
    147/147 [==============================] - 2s 11ms/step - loss: 32.7388 - val_loss: 37.5600
    Epoch 21/50
    147/147 [==============================] - 2s 11ms/step - loss: 36.7384 - val_loss: 19.3809
    Epoch 22/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.0202 - val_loss: 38.0124
    Epoch 23/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.7241 - val_loss: 36.0455
    Epoch 24/50
    147/147 [==============================] - 2s 10ms/step - loss: 33.6021 - val_loss: 19.4785
    Epoch 25/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.5922 - val_loss: 37.5662
    Epoch 26/50
    147/147 [==============================] - 2s 10ms/step - loss: 31.7600 - val_loss: 25.8877
    Epoch 27/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.0494 - val_loss: 25.5513
    Epoch 28/50
    147/147 [==============================] - 2s 11ms/step - loss: 32.7150 - val_loss: 22.6177
    Epoch 29/50
    147/147 [==============================] - 2s 11ms/step - loss: 30.3998 - val_loss: 26.8450
    Epoch 30/50
    147/147 [==============================] - 2s 10ms/step - loss: 30.3076 - val_loss: 42.8708
    Epoch 31/50
    147/147 [==============================] - 2s 11ms/step - loss: 30.6752 - val_loss: 32.9248
    Epoch 32/50
    147/147 [==============================] - 2s 10ms/step - loss: 29.2235 - val_loss: 33.0209
    Epoch 33/50
    147/147 [==============================] - 2s 11ms/step - loss: 30.7826 - val_loss: 21.4303
    Epoch 34/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.5795 - val_loss: 28.7224
    Epoch 35/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.2187 - val_loss: 19.5436
    Epoch 36/50
    147/147 [==============================] - 2s 10ms/step - loss: 28.8158 - val_loss: 23.3435
    Epoch 37/50
    147/147 [==============================] - 2s 10ms/step - loss: 27.8942 - val_loss: 29.7689
    Epoch 38/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.8379 - val_loss: 19.7113
    Epoch 39/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.4185 - val_loss: 30.7159
    Epoch 40/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.2826 - val_loss: 22.0266
    Epoch 41/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.3911 - val_loss: 22.6929
    Epoch 42/50
    147/147 [==============================] - 2s 10ms/step - loss: 28.0742 - val_loss: 16.1369
    Epoch 43/50
    147/147 [==============================] - 2s 11ms/step - loss: 27.4483 - val_loss: 19.0667
    Epoch 44/50
    147/147 [==============================] - 2s 11ms/step - loss: 27.6157 - val_loss: 15.3852
    Epoch 45/50
    147/147 [==============================] - 2s 11ms/step - loss: 27.9996 - val_loss: 21.4107
    Epoch 46/50
    147/147 [==============================] - 2s 11ms/step - loss: 28.4632 - val_loss: 17.0626
    Epoch 47/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.0796 - val_loss: 21.7797
    Epoch 48/50
    147/147 [==============================] - 2s 10ms/step - loss: 28.2646 - val_loss: 21.8080
    Epoch 49/50
    147/147 [==============================] - 2s 11ms/step - loss: 28.7243 - val_loss: 18.9899
    Epoch 50/50
    147/147 [==============================] - 2s 11ms/step - loss: 28.2579 - val_loss: 28.6534
    [[[72 73 74 68]
      [71 74 76 66]
      [71 74 76 66]]]
    [[72.415985 69.27797  71.99651  69.86983 ]]
    expected:  [71 74 76 66]
    [[[74 71 72 69]
      [72 73 74 68]
      [72 73 74 68]
      [72 73 74 68]
      [72 73 74 68]]]
    [[153.16042 179.3388  158.57655 169.93341]]
    expected:  [71 74 76 66]
    0 [[73.17023 69.77195 71.62949 71.44139]]
    expected:  [[72 73 74 68]]
    1 [[72.80142  69.71678  71.557175 71.15702 ]]
    expected:  [[72 73 74 68]]
    2 [[72.39997  69.51012  71.5443   70.574905]]
    expected:  [[72 73 74 68]]
    3 [[72.39997  69.51012  71.5443   70.574905]]
    expected:  [[71 74 76 66]]
    4 [[72.51985  69.45031  71.813896 70.3402  ]]
    expected:  [[71 74 76 66]]
    5 [[72.415985 69.27797  71.99651  69.86983 ]]
    expected:  [[71 74 76 66]]
    6 [[72.11394  68.977165 72.128334 69.17176 ]]
    expected:  [[71 74 76 66]]
    7 [[72.11394  68.977165 72.128334 69.17176 ]]
    expected:  [[71 76 74 61]]
    8 [[72.221664 69.22221  71.957596 68.933846]]
    expected:  [[71 76 74 61]]
    9 [[72.15421  69.480225 71.38563  68.43072 ]]
    expected:  [[71 76 74 61]]

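If a longer passage has to be covered while keeping the trained window, one option in the spirit of this answer is to slide a length-3 window across the sequence instead of feeding it all at once (a sketch; predict_sliding and seq are illustrative names, not from the answer):

import numpy as np

def predict_sliding(model, seq, window=3):
    # Predict the step after every length-`window` slice of seq (a (T, 4) array).
    preds = []
    for t in range(len(seq) - window + 1):
        x = seq[t:t + window].reshape(1, window, 4)
        preds.append(model.predict(x, verbose=0)[0])
    return np.array(preds)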