我建立了一个非常简单的模型来预测给定旋律的和谐。旋律表示为单热矢量,和谐为k-hot矢量(稀疏矢量,最多5个)。一批和声/旋律具有(batch_size, 64, 48)
形状。
不幸的是,在几次迭代(~100)之后,这个模型开始在各处预测零(prediction
张量充满零)并永远停留在这里。损失减少到某一点,它似乎只是停留在这里(可能是因为我的输出充满了零)。
手动检查一些记录后,数据集似乎没问题。我的数据集包含~10k记录。
这是我的模特:
import itertools
import tensorflow as tf
from dataset import get_batch
timesteps = 64
melody_dim = 48
harmony_dim = 48
batch_size = tf.placeholder(dtype=tf.int32)
melody_batch, harmony_batch = get_batch(['./dataset.tfrecords'], batch_size=batch_size)
cell = tf.contrib.rnn.LSTMCell(128)
output, state = tf.nn.dynamic_rnn(cell, melody_batch, dtype=tf.float32)
logits = tf.contrib.layers.fully_connected(output, harmony_dim, activation_fn=None)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=harmony_batch, logits=logits))
solver = tf.train.AdamOptimizer(1e-4).minimize(loss)
prediction = tf.greater(logits, 0)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in itertools.count():
_, curr_loss = sess.run([solver, loss], {
batch_size: 32
})
coord.request_stop()
coord.join(threads)
有人知道问题出在哪里吗?我应该使用嵌入或类似的东西吗?我怀疑问题是因为我的标签(和声)向量非常稀疏。