LSTM具有12个输入和输出节点,没有嵌入

时间:2015-06-30 07:10:16

标签: python neural-network

我需要的是一个RNN LSTM,它有12个输入节点,12个输出节点,并且能够调整隐藏层(它们的数量和大小)。

输入和输出向量的元素可以是实数或整数(输入时我有整数)。是否有必要在这里使用一个热编码(因为它将无法使用这么多组合)?我认为这个层是多余的,因为我的输入已经是向量。

我无法用Python(Lasagne,Block,Keras ......)和Torch构建这个神经网络。

我到目前为止最接近的是Pybrain,但是这个软件包处于“maintance”模式(只有bug修复),非常慢(它不是在Theano上构建)并且只支持一个LSTM层,这是不够的。但至少它做了我想要的 - 它需要一个12个数字的向量,并返回另一个12个数字的向量。

以下是使用pybrain的示例:

# Preparing data
from pybrain.datasets import SequentialDataSet
from itertools import cycle

sp = 4000
data = np.random.randint(1,100,(5000,12))
def splt_seq(data):
    sq = SequentialDataSet(12, 12)
    for sample, next_sample in zip(data, cycle(data[1:])):
        sq.addSample(sample, next_sample)
    return(sq)

train = splt_seq(data[:sp])
test = splt_seq(data[sp:sp+200])


# Building network and training
from pybrain.tools.shortcuts import buildNetwork
from pybrain.structure.modules import LSTMLayer

net = buildNetwork(12, 100, 12,
           hiddenclass=LSTMLayer, outputbias=False, recurrent=True)

from pybrain.supervised import RPropMinusTrainer
from sys import stdout

trainer = RPropMinusTrainer(net, dataset=train)
train_errors = [] # save errors for plotting later
EPOCHS_PER_CYCLE = 5
CYCLES = 100
EPOCHS = EPOCHS_PER_CYCLE * CYCLES
for i in range(CYCLES):
    trainer.trainEpochs(EPOCHS_PER_CYCLE)
    train_errors.append(trainer.testOnData())
    epoch = (i+1) * EPOCHS_PER_CYCLE
    print("\r epoch {}/{}".format(epoch, EPOCHS), end="")
    stdout.flush()

print()
print("final error =", train_errors[-1])

net.activate(X_test.getSample()[0])

使用keras我可以看到here

1 个答案:

答案 0 :(得分:1)

这取决于。除了长度和类型之外,您没有指定输入的详细信息。输入矢量的元素是离散的还是连续的?如果它们是离散的,则必须对矢量进行单热编码。否则,您可以直接将数据提供给RNN-LSTM。

离散和连续之间的区别:

假设您的向量包含有关从一副牌中随机挑选的12张牌的信息。你可能已经将你的牌从0到51编入索引(牌组中有52张牌),输入的矢量看起来像这样:

[3,4,17,50,20,10,11,36,5,0,23,49]

现在你遇到了问题,每张卡片的指数并不代表卡片的任何量化指标(卡片50不是卡片10的5倍以上)。因此,您必须对矢量进行一次热编码,以便卡片将居住在更大的空间中,彼此等距: [e(3,52),e(4,52),e(17,52)...... e(23,52)]

如果您的输入包含连续数据,如天气信息,(12个元素中的每一个都是不同的因素,如温度,风,湿度等),一个热编码就没有任何意义。只需按原样将矢量输入RNN-LSTM。

正如您所提到的,输入也可能是实数,您的输入更可能是连续的,不需要进行热编码。 希望有所帮助!