在训练LSTM时从csv加载文件时出错

时间:2017-01-01 11:18:37

标签: python csv tensorflow deep-learning lstm

我对tensorflow相对较新,我试图在tensorflow中为我抓取的一些数据训练一个两层LSTM,并保存为CSV。但是,在我使用tensorflow网站上显示的方法后,我不断收到以下错误:

  

TypeError:输入必须是序列

原始代码是:

file = tf.train.string_input_producer(['players_raw.csv'],
num_epochs=100, shuffle=False)
reader = tf.TextLineReader()
key, val = reader.read(file)
gameNum, age, team, homeAway, opponent, pointDiff, secs, orb, drb, ast, stl, blk, to, pts, fanPts = tf.decode_csv(val, record_defaults=defaults)
features = tf.pack([gameNum, age, team, homeAway, opponent, pointDiff, secs, orb, drb, ast, stl, blk, to, pts])
label = tf.pack([fanPts]);

lstmCell = rnn_cell.LSTMCell(NUM_FEATURES)
stacked = rnn_cell.MultiRNNCell([lstmCell] * 2)
outputs, states = rnn.rnn(stacked, features, dtype=tf.float32)

最后一行是导致错误的原因。我想我明白问题是什么,但我不确定如何修复它

0 个答案:

没有答案