AssertionError when training an RNN with PyTorch Lightning

Posted: 2021-05-12 23:34:22

Tags: python pytorch pytorch-lightning

I'm new to PyTorch, so I'm using PyTorch Lightning to train a simple (vanilla) RNN:

1. Data preparation

import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import numpy as np
import pandas as pd

#...
# X_train, Y_train are NumPy arrays with shape (n, t, d)
X_train_tensors = torch.Tensor(X_train).to(device)
Y_train_tensors = torch.Tensor(Y_train).to(device)
#create train dataset
train = TensorDataset(X_train_tensors, Y_train_tensors)
trainloader = DataLoader(train, batch_size=32, shuffle=False)
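Since X_train and Y_train are elided above, here is a minimal sketch that checks what the DataLoader yields, using hypothetical random arrays with the stated (n, t, d) layout (n=100, t=10, d=1 are illustrative assumptions, not the asker's data). With batch_first=True, nn.RNN expects exactly this (batch, time, features) layout.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the elided X_train / Y_train, shape (n, t, d)
n, t, d = 100, 10, 1
X_train = np.random.rand(n, t, d).astype(np.float32)
Y_train = np.random.rand(n, t, d).astype(np.float32)

# torch.from_numpy keeps the float32 dtype and shares memory with the arrays
train = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(Y_train))
trainloader = DataLoader(train, batch_size=32, shuffle=False)

x, y = next(iter(trainloader))
print(x.shape, y.shape)  # torch.Size([32, 10, 1]) torch.Size([32, 10, 1])

# 100 = 3 * 32 + 4, so the last batch is a partial one of 4 samples
last_x, _ = list(trainloader)[-1]
print(last_x.shape)  # torch.Size([4, 10, 1])
```

The sketch skips .to(device): Lightning's Trainer moves each batch to the model's device itself, so moving the whole dataset up front is not required.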

2. Create the Learner class

#use pl to create learner
class Learner(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.MSELoss()(y_hat, y)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
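For readers new to Lightning: for each batch, the Trainer roughly calls training_step, backpropagates the returned 'loss', and steps the optimizer returned by configure_optimizers. Below is a simplified plain-PyTorch sketch of that loop (ignoring logging, callbacks and device handling). It uses a hypothetical stand-in model, nn.Linear(1, 1) applied at every time step, instead of the RNN stack, and lr=0.01 so the toy problem visibly converges.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model: nn.Linear acts on the last dim of (batch, t, 1)
model = nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # cf. configure_optimizers
loss_fn = nn.MSELoss()

# One fixed hypothetical batch whose target is 2 * x
x = torch.rand(32, 10, 1)
y = 2 * x

losses = []
for step in range(100):
    y_hat = model(x)            # the forward pass done in training_step
    loss = loss_fn(y_hat, y)
    optimizer.zero_grad()       # what Lightning does with the returned loss
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```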

3. Create the model and run the Trainer

NN = nn.Sequential(
            nn.RNN(1,3,nonlinearity='tanh',batch_first=True),
            nn.RNN(3,5,nonlinearity='tanh',batch_first=True),
            nn.Linear(5, 1)
        )
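The likely culprit is this model definition: nn.RNN's forward returns a tuple (output, h_n), not a single tensor, and nn.Sequential passes each module's return value unchanged to the next module. A quick check on a dummy batch (the shape (32, 10, 1) is an assumption):

```python
import torch
from torch import nn

rnn = nn.RNN(1, 3, nonlinearity='tanh', batch_first=True)
x = torch.randn(32, 10, 1)  # hypothetical batch: (batch, time, features)

result = rnn(x)
print(type(result))  # <class 'tuple'>

output, h_n = result
print(output.shape)  # torch.Size([32, 10, 3]): hidden state at every time step
print(h_n.shape)     # torch.Size([1, 32, 3]): final state, (num_layers, batch, hidden)
```

Note that h_n keeps the (num_layers, batch, hidden) layout even with batch_first=True; only output is batch-first.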

model = Learner(NN)
trainer = pl.Trainer(max_epochs=100, weights_summary='full')
trainer.fit(model, train_dataloader=trainloader)

I get this AssertionError:

AssertionError                            Traceback (most recent call last)
<ipython-input-29-781e293c05ed> in <module>()
     10 model = Learner(NN)
     11 trainer = pl.Trainer(max_epochs=100, weights_summary='full')
---> 12 trainer.fit(model, train_dataloader=trainloader)
     13 #only when called it uses the test loop
     14 trainer.test(model, testloader)

16 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    242             max_batch_size = int(batch_sizes[0])
    243         else:
--> 244             assert isinstance(input, Tensor)
    245             batch_sizes = None
    246             max_batch_size = input.size(0) if self.batch_first else input.size(1)
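The error does not originate in Lightning: feeding a dummy batch straight into the same nn.Sequential reproduces it, because the second nn.RNN receives the (output, h_n) tuple from the first one. On the PyTorch version in this traceback that surfaces as the AssertionError above; other releases may raise a different exception type, so this sketch catches any Exception (the batch shape is an assumption).

```python
import torch
from torch import nn

NN = nn.Sequential(
    nn.RNN(1, 3, nonlinearity='tanh', batch_first=True),
    nn.RNN(3, 5, nonlinearity='tanh', batch_first=True),
    nn.Linear(5, 1),
)

x = torch.randn(32, 10, 1)  # hypothetical batch: (batch, time, features)

first_out = NN[0](x)
print(type(first_out))  # <class 'tuple'>: this is what NN[1] receives

try:
    NN(x)
    error = None
except Exception as exc:  # AssertionError in the traceback above; type may vary by version
    error = exc
print(type(error).__name__)
```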

But when I look at the base RNN class (RNNBase) on GitHub, https://github.com/pytorch/pytorch/blob/d09abf004cc16f8fd5f320e3d5d07c383c174ea7/torch/nn/modules/rnn.py#L247, I can't find this assert!

Can you help?

1 Answer:

Answer 0 (score: 0)

A fix for this error was added just yesterday (which is why you don't see the assert in the GitHub code). See the related issue and its PR here:

https://github.com/pytorch/pytorch/issues/55868
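Independently of that upstream change, nn.RNN still returns an (output, h_n) tuple, so in user code the model has to keep only the output tensor before the next layer. A minimal sketch, assuming per-time-step predictions shaped like Y_train (n, t, 1) are wanted; the wrapper name TakeOutput is hypothetical and not a PyTorch class.

```python
import torch
from torch import nn


class TakeOutput(nn.Module):
    """Hypothetical helper: wraps an RNN and returns only its output tensor."""

    def __init__(self, rnn: nn.Module):
        super().__init__()
        self.rnn = rnn

    def forward(self, x):
        output, _h_n = self.rnn(x)  # discard the final hidden state
        return output


NN = nn.Sequential(
    TakeOutput(nn.RNN(1, 3, nonlinearity='tanh', batch_first=True)),
    TakeOutput(nn.RNN(3, 5, nonlinearity='tanh', batch_first=True)),
    nn.Linear(5, 1),  # applied at every time step: (batch, t, 5) -> (batch, t, 1)
)

x = torch.randn(32, 10, 1)  # hypothetical batch: (batch, time, features)
y_hat = NN(x)
print(y_hat.shape)  # torch.Size([32, 10, 1])
```

This NN can be passed to Learner(NN) unchanged. If both layers may share one hidden size, nn.RNN(1, 5, num_layers=2, batch_first=True) stacks them internally, but it still returns a tuple and needs the same unpacking before nn.Linear.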