我是 PyTorch 的新手,所以我使用 PyTorch-Lightning 来训练简单的(Vanilla)RNN:
1.数据准备
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import numpy as np
import pandas as pd
#...
# X_train and Y_train are numpy arrays of shape (n, t, d).
# Convert both to tensors on `device` in one pass.
X_train_tensors, Y_train_tensors = (
    torch.Tensor(array).to(device) for array in (X_train, Y_train)
)
# Build the training dataset and a loader that keeps the sample order.
train = TensorDataset(X_train_tensors, Y_train_tensors)
trainloader = DataLoader(dataset=train, batch_size=32, shuffle=False)
2.创建Learner类
#use pl to create learner
class Learner(pl.LightningModule):
    """Lightning wrapper that trains an arbitrary ``nn.Module`` with MSE loss.

    Args:
        model: Network whose output for an input batch ``x`` can be compared
            with the target batch ``y`` by a mean-squared-error loss.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped model's prediction for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss of one ``(x, y)`` batch.

        Returns:
            A dict whose ``'loss'`` entry is the tensor Lightning
            back-propagates.
        """
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        # Log through self.log(); the former {'log': ...} entry in the return
        # dict is a legacy mechanism that Lightning 1.x is understood not to
        # pick up for logging -- confirm against the installed version.
        self.log('train_loss', loss)
        return {'loss': loss}

    def configure_optimizers(self):
        """Optimise the wrapped model's parameters with Adam (lr=0.001)."""
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
3.创建模型并使用训练器
class _RNNOutput(nn.Module):
    """Wrap a recurrent layer so that it returns only its output sequence.

    ``nn.RNN.forward`` returns a tuple ``(output, h_n)``, while
    ``nn.Sequential`` hands each layer's return value unchanged to the next
    layer. Without this wrapper the second ``nn.RNN`` receives a tuple
    instead of a tensor and fails its ``isinstance(input, Tensor)`` check.
    """

    def __init__(self, rnn: nn.Module):
        super().__init__()
        self.rnn = rnn

    def forward(self, x):
        # Drop the final hidden state; only the per-step outputs feed the
        # next layer.
        output, _ = self.rnn(x)
        return output


# Two stacked tanh RNN layers (1 -> 3 -> 5 features) followed by a linear
# read-out that is applied to every time step.
NN = nn.Sequential(
    _RNNOutput(nn.RNN(1, 3, nonlinearity='tanh', batch_first=True)),
    _RNNOutput(nn.RNN(3, 5, nonlinearity='tanh', batch_first=True)),
    nn.Linear(5, 1),
)
# Wrap the network in the Lightning module and train it for 100 epochs.
model = Learner(NN)
trainer = pl.Trainer(max_epochs=100, weights_summary='full')
# Pass the dataloader positionally: the keyword was `train_dataloader` in
# older Lightning releases and `train_dataloaders` in later ones (verify
# against the installed version); the positional form works with both.
trainer.fit(model, trainloader)
运行时出现了下面的断言错误(AssertionError):
AssertionError Traceback (most recent call last)
<ipython-input-29-781e293c05ed> in <module>()
10 model = Learner(NN)
11 trainer = pl.Trainer(max_epochs=100, weights_summary='full')
---> 12 trainer.fit(model, train_dataloader=trainloader)
13 #only when called it uses the test loop
14 trainer.test(model, testloader)
16 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
242 max_batch_size = int(batch_sizes[0])
243 else:
--> 244 assert isinstance(input, Tensor)
245 batch_sizes = None
246 max_batch_size = input.size(0) if self.batch_first else input.size(1)
我查看了 GitHub 上的 `RNNBase` 源码(https://github.com/pytorch/pytorch/blob/d09abf004cc16f8fd5f320e3d5d07c383c174ea7/torch/nn/modules/rnn.py#L247),却没有找到这条 assert 语句!
你能帮忙吗?
答案 0 :(得分:0)
昨天为这个错误添加了一个修复程序(这就是为什么你没有在 github 代码中看到它)。在此处查看相关问题的 PR: