Improving the accuracy of a Keras GRU

Asked: 2020-10-29 11:28:01

Tags: tensorflow keras

I am trying to predict the next piano note from the notes played before it. The input and target data (values taken from classical-piano .mid files) are of the form

x_train=[ [[1,2,3,4,5],[0,0,0,0,0],[0,0,0,0,0]], [[1,2,3,4,5],[4,5,6,7,8],[0,0,0,0,0]] ]
#notes not played yet are [0,0,0,0,0]
#y_train is the next note played
y_train= [ [[4,5,6,7,8]], [[10,11,12,13,14]] ]

Problem: my accuracy is low (~45%), and the predicted next note is always the same (or eventually becomes the same).

Training:

import numpy as np
from tensorflow import keras

(x_train,y_train)=create_data()

x_train=np.array(x_train)
y_train=np.array(y_train)

x_train=x_train.astype("int")
y_train=y_train.astype("int")

x_train=x_train[:500]
y_train=y_train[:500]

model=keras.Sequential()
model.add(keras.Input(shape=(500,5)))
model.add(keras.layers.GRU(5,activation='linear'))
model.add(keras.layers.Dense(1*5))

model.compile(
        loss=keras.losses.MeanAbsoluteError(),
        optimizer=keras.optimizers.Adam(lr=0.001),
        metrics=["accuracy"]
    )

model.fit(x_train,y_train)

Composing a song, seeded with the first note taken from x_train:

currentNote=x_train[0].tolist()

i=0
while i<499:
    feed=[currentNote]
    feed=np.array(feed)
    output=model.predict(feed)
    output=np.absolute(output)
    output=output[0].astype("int").tolist()
    print(output)                       # printing next note predicted
    currentNote[i+1]=output
    i+=1

1 Answer:

Answer 0 (score: 0)

First of all, you are using the wrong loss. Since this is a classification problem, you should use categorical_crossentropy or sparse_categorical_crossentropy. Also, your final activation function (currently unspecified) should be 'softmax'.
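A minimal sketch of what that change could look like, assuming the notes are re-encoded as integer class IDs in [0, NUM_NOTES); NUM_NOTES, SEQ_LEN, and the layer sizes are illustrative placeholders, not values from the question:

```python
import numpy as np
from tensorflow import keras

NUM_NOTES = 128  # assumed vocabulary size (e.g. the MIDI pitch range)
SEQ_LEN = 3      # assumed input sequence length

# Each input step is a one-hot note; the target is the next note's class ID.
model = keras.Sequential([
    keras.Input(shape=(SEQ_LEN, NUM_NOTES)),
    keras.layers.GRU(64),
    keras.layers.Dense(NUM_NOTES, activation="softmax"),  # distribution over notes
])
model.compile(
    loss="sparse_categorical_crossentropy",  # integer targets, no one-hot needed
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    metrics=["accuracy"],
)
```

With sparse_categorical_crossentropy the targets stay as plain integers; categorical_crossentropy would instead expect one-hot target vectors.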

Additionally, you should sample from the output probabilities instead of always taking the highest one. There is a TensorFlow tutorial that covers this.

Note: it is important to sample from this distribution, because taking the argmax of the distribution can easily get the model stuck in a loop.

This is roughly how you could do it:

example_batch_predictions = model(X_test)

sampled_indices = tf.random.categorical(example_batch_predictions[0], 
                                        num_samples=1)
sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()

which yields an array like:
array([41, 60,  3, 31, 47, 21, 61,  6, 56, 42, 39, 40, 52, 60, 37, 37, 27,
       11,  6, 56, 64, 62, 43, 42,  6, 34,  1, 30, 16, 45, 46, 11, 17,  8,
       26,  8,  1, 46, 37, 21, 37, 53, 34, 49,  5, 58, 11,  9, 42, 62, 14,
       56, 56, 30, 31, 32, 63, 53, 10, 23, 35,  5, 19, 19, 46,  3, 23, 63,
       61, 11, 57,  0, 35, 48, 32,  4, 37,  7, 48, 23, 39, 30, 20, 26,  1,
       52, 57, 23, 46, 56, 11, 22,  7, 47, 16, 27, 38, 51, 55, 28])
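The same sampling idea can be sketched without TensorFlow, in plain NumPy; sample_note and the temperature parameter are illustrative additions, not part of the tutorial code:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_note(probs, temperature=1.0):
    """Draw a note index from a probability vector instead of taking argmax."""
    logits = np.log(np.asarray(probs) + 1e-9) / temperature
    p = np.exp(logits - logits.max())
    p /= p.sum()  # renormalize after temperature rescaling
    return int(rng.choice(len(p), p=p))

probs = [0.1, 0.6, 0.3]
notes = [sample_note(probs) for _ in range(10)]  # varied picks, unlike argmax
```

With temperature < 1 the distribution sharpens toward argmax; with temperature > 1 it flattens, which helps break the repetition loop described in the question.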