我有7个连续输入变量,我想估计1个连续变量(y = f(x_1, ... x_8)
)。
该数据集大约有26000个量度,但很快就会增长到数百万个量度。
我有每个小节的时间。
我成功地使用PyTorch构建了具有线性层和ReLU的神经网络,但是我想将过去100种测量方法考虑在内。
我考虑构建RNN,尤其是GRU或LSTM,因为我发现它们不如Elman RNN麻烦。
我建立了有史以来最简单的类:
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers):
super(RNN, self).__init__() # compulsory for pytorch
# Parameters
self.nHiddenFeatures = hidden_size
self.nLayers = num_layers
self.nHiddenNeurons = hidden_size
# Layers
self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, X, hidden):
# initial hidden input sequence
Y, hidden_ = self.gru(X, hidden)
Y = Y.contiguous().view(-1, self.hidden_dim)
Y = self.fc(Y)
return F.relu(Y), hidden_ # output variable can't be negative
# Parameters:
input_size = 7
output_size = 1
hidden_size = 4
num_layers = 2
seq_length = 100
# model
rnn = RNN(input_size=input_size, hidden_size=hidden_size, output_size=output_size, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=learningRate)
loss_func = nn.MSELoss()
但是现在我正在努力弄清楚如何训练它:特别是在为X_train数据集准备正确的形状时。我目前的情况是X_train.shape = torch.Size([26000, 7])
我记得其中26000是时间戳记的数量,7是在每个时间戳记中测量的变量的数量。
我已经阅读了有关PyTorch RNN的文档,并且我知道输入张量的形状应为(seq_len, batch, input_size)
:
我的问题是如何从当前张量创建一个输入张量,其中序列在考虑前100个时间戳的情况下运行所有时间戳:我是对的,试图在序列重叠的地方构建此张量因此它的形状大约是X_train.shape = torch.Size([100, 26000, 7])
,还是我应该像在互联网上的许多示例中看到的那样创建单独的序列,以X_train.shape = torch.Size([100, 260, 7])
结尾?在我看来,此解决方案似乎仅考虑260个时间戳中的上次时间戳,而不考虑其余时间戳...
最终目标是接受定期培训,例如:
hidden = torch.zeros(..., ..., ...) # initialization
for epoch in range(nEpochs):
optimizer.zero_grad() # set gradient to zero in each step
Y_estTrain= rnn(X_train, hidden) # prediction of all samples
loss = loss_func(Y_estTrain, Y_trueTrain) # difference between predicted and expected
loss.backward(loss)
optimizer.step() # update weights of the NN
if epoch % printLossOnceEvery == 0: # check loss decrease during training
print(f"Epoch = {epoch}, MSE = {loss.item():0.1f}")
感谢您阅读整个问题:有点长,但是我希望添加尽可能多的信息。