Predicting a sine wave with an LSTM

Date: 2017-08-16 10:01:30

Tags: mxnet

I am trying to reproduce, in MXNet, the LSTM usage from the TensorFlow example at https://github.com/mouradmourafiq/tensorflow-lstm-regression/blob/master/lstm_sin.ipynb. Here is my main code:

import mxnet as mx
import numpy as np
import pandas as pd
import argparse
import os
import sys
from data_processing import generate_data
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
TIMESTEPS = 3
BATCH_SIZE = 100
X, y = generate_data(np.sin, np.linspace(0, 100, 10000), TIMESTEPS, seperate=False)
train_iter = mx.io.NDArrayIter(X['train'], y['train'], batch_size=BATCH_SIZE, shuffle=True, label_name='lro_label')
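# the evaluation iterator must use the same label name as the training iterator and the Module
eval_iter = mx.io.NDArrayIter(X['val'], y['val'], batch_size=BATCH_SIZE, shuffle=False, label_name='lro_label')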
test_iter = mx.io.NDArrayIter(X['test'], batch_size=BATCH_SIZE, shuffle=False)
num_layers = 3
num_hidden = 50

data = mx.sym.Variable('data')
label = mx.sym.Variable('lro_label')

stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
    stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
#stack.reset()
outputs, states = stack.unroll(length=TIMESTEPS,
                               inputs=data,
                               layout='NTC',
                               merge_outputs=True)

outputs = mx.sym.reshape(outputs, shape=(BATCH_SIZE, -1))
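# fc1 projects the flattened (BATCH_SIZE, TIMESTEPS*num_hidden) output to (BATCH_SIZE, 1), i.e. one prediction per sample, so that it lines up with the label; the raw unrolled LSTM output has one vector per time step and would not match the label shape.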
outputs = mx.sym.FullyConnected(data=outputs, num_hidden=1, name='fc1')
label = mx.sym.reshape(label, shape=(-1,))
outputs = mx.sym.LinearRegressionOutput(data=outputs,
                                        label=label,
                                        name='lro')
contexts = mx.cpu(0)
model = mx.mod.Module(symbol=outputs,
                      data_names=['data'],
                      label_names=['lro_label'],
                      context=contexts)
model.fit(train_iter, eval_iter,
          optimizer_params={'learning_rate': 0.005},
          num_epoch=4,
          batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 2))
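The script relies on `generate_data` from a `data_processing` module that is not shown; it presumably comes from the linked TensorFlow repository. Below is a hypothetical NumPy stand-in, not that module's code: `make_windows` and `split` are illustrative names, and the 10% validation/test fractions are assumptions. It shows the data layout the network above expects: inputs of shape (samples, TIMESTEPS, 1) for the 'NTC' layout, and one next-step target per sample.

```python
import numpy as np

def make_windows(series, time_steps):
    """Hypothetical stand-in for the windowing done by generate_data.

    Each input sample is `time_steps` consecutive values, shaped
    (time_steps, 1) for an NTC layout with a single feature; its label
    is the value that immediately follows the window.
    """
    n = len(series) - time_steps
    X = np.stack([series[i:i + time_steps] for i in range(n)])[:, :, None]
    y = series[time_steps:time_steps + n].reshape(-1, 1)
    return X.astype(np.float32), y.astype(np.float32)

def split(arr, val_frac=0.1, test_frac=0.1):
    """Chronological train/val/test split (the fractions are assumptions)."""
    n = len(arr)
    n_test = int(n * test_frac)
    n_val = int((n - n_test) * val_frac)
    train_end = n - n_test - n_val
    return {'train': arr[:train_end],
            'val': arr[train_end:n - n_test],
            'test': arr[n - n_test:]}

TIMESTEPS = 3
series = np.sin(np.linspace(0, 100, 10000))
X_all, y_all = make_windows(series, TIMESTEPS)
X, y = split(X_all), split(y_all)
print(X['train'].shape, y['train'].shape)  # (8099, 3, 1) (8099, 1)
```

Arrays of this shape can be handed to `mx.io.NDArrayIter` in the same way as `X['train']` and `y['train']` in the script; the split is chronological so that validation and test windows come from later parts of the curve.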

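This code runs, but the training accuracy is NaN. How can I make it correct? Since the unrolled output shape includes the sequence length, how does it match the label shape? Does my fc1 layer make sense?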
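The shape question can be checked by tracing shapes through the network with plain NumPy. In this sketch, random arrays stand in for the LSTM activations and for fc1's weights (illustrative values, not MXNet internals); BATCH_SIZE=100, TIMESTEPS=3 and num_hidden=50 match the script. The trace shows that the reshape followed by fc1 is what collapses the per-time-step output to one prediction per sample, which is why the label only needs one value per sample.

```python
import numpy as np

BATCH_SIZE, TIMESTEPS, num_hidden = 100, 3, 50
rng = np.random.default_rng(0)

# Stand-in for the unrolled LSTM output with merge_outputs=True and layout='NTC':
# one hidden vector per time step -> (N, T, H)
outputs = rng.standard_normal((BATCH_SIZE, TIMESTEPS, num_hidden))

# reshape(outputs, shape=(BATCH_SIZE, -1)) merges the time and hidden axes -> (N, T*H)
flat = outputs.reshape(BATCH_SIZE, -1)

# A fully connected layer with one output unit: flat @ W.T + b, W has shape (1, T*H) -> (N, 1)
W = rng.standard_normal((1, TIMESTEPS * num_hidden))
b = np.zeros(1)
pred = flat @ W.T + b

# The label reshaped to (N,): one target per sample
label = rng.standard_normal((BATCH_SIZE, 1)).reshape(-1)

# Alternative design: keep only the last time step -> (N, H), then project to (N, 1)
last = outputs[:, -1, :]
W_last = rng.standard_normal((1, num_hidden))
pred_last = last @ W_last.T

print(flat.shape, pred.shape, label.shape, pred_last.shape)
# (100, 150) (100, 1) (100,) (100, 1)
```

The last lines show a common alternative: using only the final time step's hidden state, which also gives one prediction per sample with a smaller projection (num_hidden inputs instead of TIMESTEPS*num_hidden). Which variant fits this data better is an empirical question.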

1 answer

Answer 0 (score: 1)

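Passing auto_reset=False to the Speedometer callback, e.g. batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 2, auto_reset=False), should fix the NaN training accuracy. With the default auto_reset=True, Speedometer resets the training metric every time it logs, so if the epoch ends right after a logging step, the metric holds no samples and the epoch-end Train-accuracy is reported as NaN.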
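The effect of the reset can be mimicked in plain Python. This is a simplified model, not MXNet code: `ToyMetric` and `run_epoch` are hypothetical names. The sketch mirrors the fact that an MXNet metric reports NaN when it holds zero samples, but it simplifies the callback to "log and reset whenever the 0-based batch index is a multiple of `frequent`"; the exact batches on which Speedometer logs may differ slightly.

```python
import math

class ToyMetric:
    """Hypothetical stand-in for an evaluation metric: accumulates a sum and
    an instance count, and reports NaN when the count is zero."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum_metric = 0.0
        self.num_inst = 0

    def update(self, value, n):
        self.sum_metric += value * n
        self.num_inst += n

    def get(self):
        return float('nan') if self.num_inst == 0 else self.sum_metric / self.num_inst

def run_epoch(num_batches, frequent, auto_reset, batch_size=100):
    """Simulate one epoch with a Speedometer-like callback (simplified)."""
    metric = ToyMetric()
    for nbatch in range(num_batches):          # 0-based batch index
        metric.update(0.5, batch_size)         # pretend every batch scores 0.5
        if auto_reset and nbatch % frequent == 0:
            metric.reset()                     # what auto_reset=True does after logging
    return metric.get()                        # value reported at the end of the epoch

print(math.isnan(run_epoch(81, 2, auto_reset=True)))  # True: reset hit the last batch
print(run_epoch(81, 2, auto_reset=False))             # 0.5
print(run_epoch(80, 2, auto_reset=True))              # 0.5: one batch left after the reset
```

With auto_reset=False the metric accumulates over the whole epoch, so the epoch-end value is an average over all batches. Separately from the NaN, `Module.fit` uses `eval_metric='acc'` by default, which is a classification metric; for a regression head such as LinearRegressionOutput, passing `eval_metric='mse'` to `fit` reports mean squared error, which is more informative here.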