I am working with the Bach chorales dataset. Each chorale is roughly 100-500 time steps long, and each time step contains 4 integers (e.g. [74, 70, 65, 58]), where each integer corresponds to a note index on the piano.
I am trying to train a model that predicts the next time step (4 notes), working chronologically from the start of a chorale.
The problem: for inputs of the same size the model was trained on I get reasonable output, but for inputs of a different size I get wrong output.
What I have done so far: I used Keras' TimeseriesGenerator to produce sequences of inputs and corresponding outputs:
generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
print(generator[0])
Output:
(array([[[74, 70, 65, 58],
[74, 70, 65, 58],
[74, 70, 65, 58]]]), array([[75, 70, 58, 55]]))
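For reference, the windowing that TimeseriesGenerator performs here is equivalent to this plain-NumPy sketch (a toy illustration of the slicing semantics, not Keras' actual implementation):

```python
import numpy as np

def make_windows(dataX, dataY, length=3):
    """Slide a window of `length` steps over dataX; the target for each
    window is the dataY row immediately following it."""
    X = np.stack([dataX[i:i + length] for i in range(len(dataX) - length)])
    y = dataY[length:]
    return X, y

# toy data: 6 time steps of 4 pitches each
dataX = dataY = np.array([[74, 70, 65, 58]] * 3 + [[75, 70, 58, 55]] * 3)
X, y = make_windows(dataX, dataY, length=3)
print(X.shape)  # (3, 3, 4): 3 samples, each a window of 3 time steps
print(y[0])     # [75 70 58 55] -- the step right after the first window
```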
Then I trained an LSTM model. I use None in input_shape to allow variable-sized inputs.
n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128 , activation = 'relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit_generator(generator, epochs=500, validation_data=validation_generator)
Predicting with an input of size 3 seems to work (since the model was trained on inputs of length 3):
# demonstrate prediction
x_input = dataX[5:8]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[8])
[[[75 70 58 55]
[75 70 60 55]
[75 70 60 55]]]
[[76.25768 68.525444 59.745518 53.799873]]
expected: [77 69 62 50]
Now I try to predict with an input of a different size, length 5, and this does not work. Output for the test sample:
# demonstrate prediction
x_input = dataX[1:6]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[6])
[[[74 70 65 58]
[74 70 65 58]
[74 70 65 58]
[75 70 58 55]
[75 70 58 55]]]
[[227.16667 217.89767 213.62988 148.44817]]
expected: [75 70 60 55]
The prediction is completely wrong; it seems to be doing some kind of summing. Any input/help on why this happens and how to fix it would be greatly appreciated.
Answer 0 (score: 1)
I can offer three possible reasons why your model fails to learn.
model.add(Dense(n_features))
This is probably the main culprit in your model (but I recommend fixing all of them). The last layer of a classification model needs to be a softmax layer. So simply change it to
model.add(Dense(n_features, activation='softmax'))
In general, crossentropy works better for classification problems than mse. So try:
model.compile(optimizer='adam', loss='categorical_crossentropy')
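Note that softmax with categorical_crossentropy expects one-hot targets, so each pitch integer would first have to be encoded. A minimal sketch (the 128-class MIDI pitch range is my assumption, not stated in the question):

```python
import numpy as np

NUM_CLASSES = 128  # assumed MIDI pitch range; adjust to your data

def one_hot_pitches(step, num_classes=NUM_CLASSES):
    """Encode one time step of 4 pitch integers as a (4, num_classes) one-hot matrix."""
    return np.eye(num_classes, dtype=np.float32)[step]

encoded = one_hot_pitches(np.array([74, 70, 65, 58]))
print(encoded.shape)        # (4, 128)
print(encoded[0].argmax())  # 74
```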
LSTMs use tanh as their activation. Unless you have a very good reason to change it to relu, don't, because an LSTM does not behave like a regular feed-forward layer when its activation function is changed.
Answer 1 (score: 0)
I suggest it would be better to keep the length of x_input at 3. Below is my test code:
import sys
from keras.models import Sequential
from keras.layers import Dense, Activation, LSTM
from keras.preprocessing.sequence import TimeseriesGenerator
import numpy as np
import logger  # custom logging module, not part of Keras
logger.logger_initialize('LOGGER.log')

def bc_pitches():
    a = open('chorales.lisp', 'r')
    # parse the input as vectors and store vectors
    def obtainNum(elemSt):
        a = elemSt.split(" ")
        return int(a[1])
    bookOfLists = []
    for i in range(210):
        counter = 0
        gun = a.readline()
        if (len(gun) <= 1):  # for \n accommodation
            continue
        else:
            while (gun[counter:(counter+2)] != "(("):
                counter += 1
            tribo = gun[(counter+2):(len(gun)-4)]
            stringArr = tribo.split("))((")  # separates each vector into an element
            lister = [x.split(") (") for x in stringArr]
            # lister = map(lambda x: x.split(") ("), stringArr)  # each vector becomes
            # a list of component elements, so lister is a list of lists
            lister2 = [[obtainNum(each) for each in x] for x in lister]
            # lister2 = map(lambda x: map(obtainNum, x), lister)
            bookOfLists.append(lister2)
    pitches = np.zeros([100, 500], dtype=np.int32)
    for i in range(len(bookOfLists)):
        for j in range(len(bookOfLists[i])):
            for t in range(bookOfLists[i][j][0], bookOfLists[i][j][0] + bookOfLists[i][j][2]):
                try:
                    pitches[i][t] = bookOfLists[i][j][1]
                except:
                    print(i, j, t)
                    sys.exit()
    return pitches

pitches = bc_pitches()
dataX = dataY = (pitches[:4, :].T)[:150]
generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
for i in range(len(generator)):
    logger.info(i, generator[i])
validation_dataX = validation_dataY = (pitches[:4, :].T)[150:]
validation_generator = TimeseriesGenerator(validation_dataX, validation_dataY, length=3, batch_size=1)
n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128, activation='relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit_generator(generator, epochs=50, validation_data=validation_generator)
# demonstrate prediction
x_input = (pitches[:4, :].T)[155:158]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4, :].T)[158])
# demonstrate prediction
x_input = (pitches[:4, :].T)[151:156]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4, :].T)[156])
for i in range(10):
    yhat = model.predict(validation_generator[i][0], verbose=0)
    logger.info(i, yhat)
    logger.info('expected: ', validation_generator[i][1])
And the results:
...
100 (array([[[72, 73, 69, 73],
[72, 73, 69, 73],
[72, 73, 69, 73]]]),
array([[72, 73, 69, 73]]))
101 (array([[[72, 73, 69, 73],
[72, 73, 69, 73],
[72, 73, 69, 73]]]),
array([[74, 71, 71, 71]]))
102 (array([[[72, 73, 69, 73],
[72, 73, 69, 73],
[74, 71, 71, 71]]]),
array([[74, 71, 71, 71]]))
103 (array([[[72, 73, 69, 73],
[74, 71, 71, 71],
[74, 71, 71, 71]]]),
array([[74, 71, 71, 71]]))
104 (array([[[74, 71, 71, 71],
[74, 71, 71, 71],
[74, 71, 71, 71]]]),
array([[74, 71, 71, 71]]))
105 (array([[[74, 71, 71, 71],
[74, 71, 71, 71],
[74, 71, 71, 71]]]),
array([[74, 73, 67, 71]]))
106 (array([[[74, 71, 71, 71],
[74, 71, 71, 71],
[74, 73, 67, 71]]]),
array([[74, 73, 67, 71]]))
107 (array([[[74, 71, 71, 71],
[74, 73, 67, 71],
[74, 73, 67, 71]]]),
array([[74, 73, 67, 71]]))
108 (array([[[74, 73, 67, 71],
[74, 73, 67, 71],
[74, 73, 67, 71]]]),
array([[74, 73, 67, 71]]))
109 (array([[[74, 73, 67, 71],
[74, 73, 67, 71],
[74, 73, 67, 71]]]),
array([[74, 74, 69, 76]]))
110 (array([[[74, 73, 67, 71],
[74, 73, 67, 71],
[74, 74, 69, 76]]]),
array([[74, 74, 69, 76]]))
111 (array([[[74, 73, 67, 71],
[74, 74, 69, 76],
[74, 74, 69, 76]]]),
array([[72, 74, 71, 76]]))
112 (array([[[74, 74, 69, 76],
[74, 74, 69, 76],
[72, 74, 71, 76]]]),
array([[72, 74, 71, 76]]))
113 (array([[[74, 74, 69, 76],
[72, 74, 71, 76],
[72, 74, 71, 76]]]),
array([[71, 73, 72, 71]]))
114 (array([[[72, 74, 71, 76],
[72, 74, 71, 76],
[71, 73, 72, 71]]]),
array([[71, 73, 72, 71]]))
115 (array([[[72, 74, 71, 76],
[71, 73, 72, 71],
[71, 73, 72, 71]]]),
array([[71, 73, 72, 71]]))
116 (array([[[71, 73, 72, 71],
[71, 73, 72, 71],
[71, 73, 72, 71]]]),
array([[71, 73, 72, 71]]))
117 (array([[[71, 73, 72, 71],
[71, 73, 72, 71],
[71, 73, 72, 71]]]),
array([[69, 71, 71, 73]]))
118 (array([[[71, 73, 72, 71],
[71, 73, 72, 71],
[69, 71, 71, 73]]]),
array([[69, 71, 71, 73]]))
119 (array([[[71, 73, 72, 71],
[69, 71, 71, 73],
[69, 71, 71, 73]]]),
array([[69, 71, 71, 73]]))
120 (array([[[69, 71, 71, 73],
[69, 71, 71, 73],
[69, 71, 71, 73]]]),
array([[69, 71, 71, 73]]))
121 (array([[[69, 71, 71, 73],
[69, 71, 71, 73],
[69, 71, 71, 73]]]),
array([[69, 70, 72, 68]]))
122 (array([[[69, 71, 71, 73],
[69, 71, 71, 73],
[69, 70, 72, 68]]]),
array([[69, 70, 72, 68]]))
123 (array([[[69, 71, 71, 73],
[69, 70, 72, 68],
[69, 70, 72, 68]]]),
array([[69, 70, 71, 69]]))
124 (array([[[69, 70, 72, 68],
[69, 70, 72, 68],
[69, 70, 71, 69]]]),
array([[69, 70, 71, 69]]))
125 (array([[[69, 70, 72, 68],
[69, 70, 71, 69],
[69, 70, 71, 69]]]),
array([[67, 71, 69, 71]]))
126 (array([[[69, 70, 71, 69],
[69, 70, 71, 69],
[67, 71, 69, 71]]]),
array([[67, 71, 69, 71]]))
127 (array([[[69, 70, 71, 69],
[67, 71, 69, 71],
[67, 71, 69, 71]]]),
array([[67, 71, 69, 71]]))
128 (array([[[67, 71, 69, 71],
[67, 71, 69, 71],
[67, 71, 69, 71]]]),
array([[67, 71, 69, 71]]))
129 (array([[[67, 71, 69, 71],
[67, 71, 69, 71],
[67, 71, 69, 71]]]),
array([[71, 71, 68, 69]]))
130 (array([[[67, 71, 69, 71],
[67, 71, 69, 71],
[71, 71, 68, 69]]]),
array([[71, 71, 68, 69]]))
131 (array([[[67, 71, 69, 71],
[71, 71, 68, 69],
[71, 71, 68, 69]]]),
array([[71, 71, 68, 69]]))
132 (array([[[71, 71, 68, 69],
[71, 71, 68, 69],
[71, 71, 68, 69]]]),
array([[71, 71, 68, 69]]))
133 (array([[[71, 71, 68, 69],
[71, 71, 68, 69],
[71, 71, 68, 69]]]),
array([[71, 71, 69, 68]]))
134 (array([[[71, 71, 68, 69],
[71, 71, 68, 69],
[71, 71, 69, 68]]]),
array([[71, 71, 69, 68]]))
135 (array([[[71, 71, 68, 69],
[71, 71, 69, 68],
[71, 71, 69, 68]]]),
array([[71, 71, 69, 68]]))
136 (array([[[71, 71, 69, 68],
[71, 71, 69, 68],
[71, 71, 69, 68]]]),
array([[71, 71, 69, 68]]))
137 (array([[[71, 71, 69, 68],
[71, 71, 69, 68],
[71, 71, 69, 68]]]),
array([[72, 64, 69, 68]]))
138 (array([[[71, 71, 69, 68],
[71, 71, 69, 68],
[72, 64, 69, 68]]]),
array([[72, 64, 69, 68]]))
139 (array([[[71, 71, 69, 68],
[72, 64, 69, 68],
[72, 64, 69, 68]]]),
array([[72, 64, 69, 68]]))
140 (array([[[72, 64, 69, 68],
[72, 64, 69, 68],
[72, 64, 69, 68]]]),
array([[72, 64, 69, 68]]))
141 (array([[[72, 64, 69, 68],
[72, 64, 69, 68],
[72, 64, 69, 68]]]),
array([[74, 69, 76, 66]]))
142 (array([[[72, 64, 69, 68],
[72, 64, 69, 68],
[74, 69, 76, 66]]]),
array([[74, 69, 76, 66]]))
143 (array([[[72, 64, 69, 68],
[74, 69, 76, 66],
[74, 69, 76, 66]]]),
array([[74, 69, 76, 66]]))
144 (array([[[74, 69, 76, 66],
[74, 69, 76, 66],
[74, 69, 76, 66]]]),
array([[74, 69, 76, 66]]))
145 (array([[[74, 69, 76, 66],
[74, 69, 76, 66],
[74, 69, 76, 66]]]),
array([[74, 71, 72, 69]]))
146 (array([[[74, 69, 76, 66],
[74, 69, 76, 66],
[74, 71, 72, 69]]]),
array([[74, 71, 72, 69]]))
Epoch 1/50
147/147 [==============================] - 2s 16ms/step - loss: 514.8802 - val_loss: 0.0082
Epoch 2/50
147/147 [==============================] - 2s 11ms/step - loss: 51.5768 - val_loss: 0.0249
Epoch 3/50
147/147 [==============================] - 2s 11ms/step - loss: 71.6900 - val_loss: 0.0464
Epoch 4/50
147/147 [==============================] - 2s 10ms/step - loss: 47.4575 - val_loss: 0.1303
Epoch 5/50
147/147 [==============================] - 2s 10ms/step - loss: 52.6841 - val_loss: 0.5772
Epoch 6/50
147/147 [==============================] - 2s 11ms/step - loss: 47.3059 - val_loss: 5.2535
Epoch 7/50
147/147 [==============================] - 2s 11ms/step - loss: 43.6491 - val_loss: 41.2008
Epoch 8/50
147/147 [==============================] - 2s 11ms/step - loss: 37.8593 - val_loss: 28.5831
Epoch 9/50
147/147 [==============================] - 2s 11ms/step - loss: 40.8553 - val_loss: 41.5958
Epoch 10/50
147/147 [==============================] - 2s 11ms/step - loss: 34.5995 - val_loss: 57.3419
Epoch 11/50
147/147 [==============================] - 2s 11ms/step - loss: 34.2054 - val_loss: 38.9516
Epoch 12/50
147/147 [==============================] - 2s 11ms/step - loss: 36.9247 - val_loss: 38.1881
Epoch 13/50
147/147 [==============================] - 2s 10ms/step - loss: 34.5922 - val_loss: 49.7601
Epoch 14/50
147/147 [==============================] - 2s 11ms/step - loss: 38.1668 - val_loss: 46.0043
Epoch 15/50
147/147 [==============================] - 2s 10ms/step - loss: 35.4724 - val_loss: 39.1485
Epoch 16/50
147/147 [==============================] - 2s 11ms/step - loss: 35.7787 - val_loss: 38.2263
Epoch 17/50
147/147 [==============================] - 2s 11ms/step - loss: 32.5241 - val_loss: 38.0783
Epoch 18/50
147/147 [==============================] - 2s 11ms/step - loss: 35.1693 - val_loss: 35.3403
Epoch 19/50
147/147 [==============================] - 2s 11ms/step - loss: 34.5822 - val_loss: 28.0546
Epoch 20/50
147/147 [==============================] - 2s 11ms/step - loss: 32.7388 - val_loss: 37.5600
Epoch 21/50
147/147 [==============================] - 2s 11ms/step - loss: 36.7384 - val_loss: 19.3809
Epoch 22/50
147/147 [==============================] - 2s 11ms/step - loss: 34.0202 - val_loss: 38.0124
Epoch 23/50
147/147 [==============================] - 2s 11ms/step - loss: 31.7241 - val_loss: 36.0455
Epoch 24/50
147/147 [==============================] - 2s 10ms/step - loss: 33.6021 - val_loss: 19.4785
Epoch 25/50
147/147 [==============================] - 2s 11ms/step - loss: 29.5922 - val_loss: 37.5662
Epoch 26/50
147/147 [==============================] - 2s 10ms/step - loss: 31.7600 - val_loss: 25.8877
Epoch 27/50
147/147 [==============================] - 2s 11ms/step - loss: 31.0494 - val_loss: 25.5513
Epoch 28/50
147/147 [==============================] - 2s 11ms/step - loss: 32.7150 - val_loss: 22.6177
Epoch 29/50
147/147 [==============================] - 2s 11ms/step - loss: 30.3998 - val_loss: 26.8450
Epoch 30/50
147/147 [==============================] - 2s 10ms/step - loss: 30.3076 - val_loss: 42.8708
Epoch 31/50
147/147 [==============================] - 2s 11ms/step - loss: 30.6752 - val_loss: 32.9248
Epoch 32/50
147/147 [==============================] - 2s 10ms/step - loss: 29.2235 - val_loss: 33.0209
Epoch 33/50
147/147 [==============================] - 2s 11ms/step - loss: 30.7826 - val_loss: 21.4303
Epoch 34/50
147/147 [==============================] - 2s 11ms/step - loss: 31.5795 - val_loss: 28.7224
Epoch 35/50
147/147 [==============================] - 2s 11ms/step - loss: 29.2187 - val_loss: 19.5436
Epoch 36/50
147/147 [==============================] - 2s 10ms/step - loss: 28.8158 - val_loss: 23.3435
Epoch 37/50
147/147 [==============================] - 2s 10ms/step - loss: 27.8942 - val_loss: 29.7689
Epoch 38/50
147/147 [==============================] - 2s 11ms/step - loss: 31.8379 - val_loss: 19.7113
Epoch 39/50
147/147 [==============================] - 2s 11ms/step - loss: 29.4185 - val_loss: 30.7159
Epoch 40/50
147/147 [==============================] - 2s 11ms/step - loss: 29.2826 - val_loss: 22.0266
Epoch 41/50
147/147 [==============================] - 2s 11ms/step - loss: 29.3911 - val_loss: 22.6929
Epoch 42/50
147/147 [==============================] - 2s 10ms/step - loss: 28.0742 - val_loss: 16.1369
Epoch 43/50
147/147 [==============================] - 2s 11ms/step - loss: 27.4483 - val_loss: 19.0667
Epoch 44/50
147/147 [==============================] - 2s 11ms/step - loss: 27.6157 - val_loss: 15.3852
Epoch 45/50
147/147 [==============================] - 2s 11ms/step - loss: 27.9996 - val_loss: 21.4107
Epoch 46/50
147/147 [==============================] - 2s 11ms/step - loss: 28.4632 - val_loss: 17.0626
Epoch 47/50
147/147 [==============================] - 2s 11ms/step - loss: 29.0796 - val_loss: 21.7797
Epoch 48/50
147/147 [==============================] - 2s 10ms/step - loss: 28.2646 - val_loss: 21.8080
Epoch 49/50
147/147 [==============================] - 2s 11ms/step - loss: 28.7243 - val_loss: 18.9899
Epoch 50/50
147/147 [==============================] - 2s 11ms/step - loss: 28.2579 - val_loss: 28.6534
[[[72 73 74 68]
[71 74 76 66]
[71 74 76 66]]]
[[72.415985 69.27797 71.99651 69.86983 ]]
expected: [71 74 76 66]
[[[74 71 72 69]
[72 73 74 68]
[72 73 74 68]
[72 73 74 68]
[72 73 74 68]]]
[[153.16042 179.3388 158.57655 169.93341]]
expected: [71 74 76 66]
0 [[73.17023 69.77195 71.62949 71.44139]]
expected: [[72 73 74 68]]
1 [[72.80142 69.71678 71.557175 71.15702 ]]
expected: [[72 73 74 68]]
2 [[72.39997 69.51012 71.5443 70.574905]]
expected: [[72 73 74 68]]
3 [[72.39997 69.51012 71.5443 70.574905]]
expected: [[71 74 76 66]]
4 [[72.51985 69.45031 71.813896 70.3402 ]]
expected: [[71 74 76 66]]
5 [[72.415985 69.27797 71.99651 69.86983 ]]
expected: [[71 74 76 66]]
6 [[72.11394 68.977165 72.128334 69.17176 ]]
expected: [[71 74 76 66]]
7 [[72.11394 68.977165 72.128334 69.17176 ]]
expected: [[71 76 74 61]]
8 [[72.221664 69.22221 71.957596 68.933846]]
expected: [[71 76 74 61]]
9 [[72.15421 69.480225 71.38563 68.43072 ]]
expected: [[71 76 74 61]]