ValueError: Expected input batch_size (59) to match target batch_size (1)

Posted: 2021-06-20 07:30:22

Tags: deep-learning pytorch image-segmentation pytorch-lightning

I am trying to build a semantic segmentation model with PyTorch, but I am getting this error and don't know how to fix it.

Here is the model:

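import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import models

class SegmentationNN(pl.LightningModule):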

    def __init__(self, num_classes=23, hparams=None):
        super().__init__()
        self.hparams = hparams
        self.model=models.alexnet(pretrained=True).features
        self.conv=nn.Conv2d(256, 3, kernel_size=1)
        self.upsample = nn.Upsample(size=(240,240))


    def forward(self, x):

        print('Input:', x.shape)
        x = self.model(x)
        print('After Alexnet convs:', x.shape)
        x = self.conv(x)
        print('After 1-conv:', x.shape)
        x = self.upsample(x)
        print('After upsampling:', x.shape)
        return x

    def training_step(self, batch, batch_idx):
        images, targets = batch
        # targets = targets.view(targets.size(0), -1)
        out = self.forward(images)
        loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
        loss = loss_func(out, targets.unsqueeze(0))
        tensorboard_logs = {'loss': loss}
        
        return {'loss': loss, 'log':tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        # targets = targets.view(targets.size(0), -1)
        out = self.forward(images)
        loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
        loss = loss_func(out, targets.unsqueeze(0))
        tensorboard_logs = {'loss': loss}
        
        return {'loss': loss, 'log':tensorboard_logs}
    
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams['learning_rate'])
        return optim

Here is the training and fitting code:

train_dataloader = DataLoader(train_data, batch_size=hparams['batch_size'])
val_dataloader = DataLoader(val_data, batch_size=hparams['batch_size'])

trainer = pl.Trainer(
    max_epochs=50,
    gpus=1 if torch.cuda.is_available() else None
)
trainer.fit(model, train_dataloader, val_dataloader)
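To see what the DataLoader actually hands to training_step, here is a minimal sketch that uses a hypothetical random TensorDataset in place of train_data. The real dataset is not shown in the question, so the 240x240 masks of integer labels (0 to 22, with -1 as the ignore value) are an assumption. It shows that the targets already arrive with the batch dimension first, and what unsqueeze(0) does to them.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for train_data: 8 RGB images of 240x240 and
# one integer label mask per image (labels 0..22, -1 = ignore).
num_samples, num_classes = 8, 23
images = torch.randn(num_samples, 3, 240, 240)
masks = torch.randint(-1, num_classes, (num_samples, 240, 240))
train_data = TensorDataset(images, masks)

loader = DataLoader(train_data, batch_size=4)
batch_images, batch_targets = next(iter(loader))

# The DataLoader stacks samples along dim 0, so the targets already
# have the batch dimension first: [batch, H, W].
print(batch_images.shape)                # torch.Size([4, 3, 240, 240])
print(batch_targets.shape)               # torch.Size([4, 240, 240])
# unsqueeze(0) adds a new leading dimension of size 1.
print(batch_targets.unsqueeze(0).shape)  # torch.Size([1, 4, 240, 240])
```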

These are the tensor sizes after each layer:

Input: torch.Size([59, 3, 240, 240])
After Alexnet convs: torch.Size([59, 256, 6, 6])
After 1-conv: torch.Size([59, 3, 6, 6])
After upsampling: torch.Size([59, 3, 240, 240])
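Note that the 1x1 conv above outputs 3 channels even though num_classes defaults to 23. nn.CrossEntropyLoss treats dim 1 of the output as the class dimension, so if the masks use labels 0 to 22 the output needs 23 channels. A minimal sketch of such a head, assuming that label range and using a random tensor with a small batch of 2 (to keep it cheap) as a stand-in for the AlexNet features:

```python
import torch
import torch.nn as nn

num_classes = 23  # the default in SegmentationNN.__init__

# Stand-in for the AlexNet feature map shown above ([59, 256, 6, 6]),
# with a batch of 2 instead of 59.
features = torch.randn(2, 256, 6, 6)

# Hypothetical head: a 1x1 conv maps the 256 feature channels to one
# score channel per class, then upsampling restores 240x240.
head = nn.Sequential(
    nn.Conv2d(256, num_classes, kernel_size=1),
    nn.Upsample(size=(240, 240)),
)
logits = head(features)
print(logits.shape)  # torch.Size([2, 23, 240, 240])
```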

I'm a beginner with PyTorch and PyTorch Lightning, so I'd appreciate any advice!

1 Answer:

Answer 0 (score: 0)

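Can you remove the unsqueeze(0) part in this line (it appears in both training_step and validation_step): loss = loss_func(out, targets.unsqueeze(0))

nn.CrossEntropyLoss expects the output as [N, C, H, W] and the target as [N, H, W]. Assuming your targets already have shape [59, 240, 240], unsqueeze(0) turns them into [1, 59, 240, 240], so the loss reads a target batch size of 1, which is exactly the error you see. Use loss = loss_func(out, targets) instead.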
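A small self-contained sketch with random tensors reproduces the error and the fix. The batch size of 4 and the random labels are assumptions chosen to keep it cheap; the loss settings match the question.

```python
import torch
import torch.nn as nn

batch, num_classes, h, w = 4, 23, 240, 240
out = torch.randn(batch, num_classes, h, w)             # model output [N, C, H, W]
targets = torch.randint(0, num_classes, (batch, h, w))  # label mask [N, H, W]

loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')

# With unsqueeze(0) the target becomes [1, N, H, W], so the loss sees a
# target batch size of 1 and raises ValueError.
try:
    loss_func(out, targets.unsqueeze(0))
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print(err)

# Passing the [N, H, W] target directly works and gives a scalar loss.
loss = loss_func(out, targets)
print(loss.shape)  # torch.Size([])
```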