这是Google Magenta package
的上下文,特别是旋律RNN model
。
我尝试使用自己的数据集训练basic_rnn并且运行良好,生成一个可用的检查点。但是,当我尝试使用attention_rnn时,通过添加" attn_length = 40"对于hparams,我在训练期间得到了错误" NaN损失。"。我已经尝试将attn_length更改为其他值,如10或20,我仍然会收到此错误。此外,我确保使用" attention_rnn"创建数据集。参数,所以不应该是一个问题。
任何人都有类似的问题吗?
以下是我使用的命令:
convert_dir_to_note_sequences
--input_dir=$INPUT_DIRECTORY
--output_file=$SEQUENCES_TFRECORD
--recursive
melody_rnn_create_dataset --config="attention_rnn" --input=".../mono_notesequences.tfrecord" --output_dir="..." --eval_ratio="0.10"
python ${MODEL}/melody_rnn_train.py --config=attention_rnn --run_dir=${OUTPUT} --sequence_example_file=${INPUT}/attention_rnn/training_melodies.tfrecord --hparams="batch_size=128,rnn_layer_sizes=[512,512],attn_length=40" --num_training_steps=20000