加载数据时维度不匹配

时间:2021-01-03 09:31:08

标签: deep-learning pytorch

我的代码如下。

import torch.nn.functional as f
# NOTE(review): not read anywhere in the code shown here; device placement is
# decided by the module-level ``device`` used with ``.to(device)``.
train_on_gpu = True


class CnnLstm(nn.Module):
    """Per-frame CNN features fed to a stacked LSTM, then a linear classifier.

    ``forward`` takes a 5-D tensor ``(batch, time_steps, channels, height,
    width)``. Every frame is passed through ``self.cnn`` independently; the
    flattened per-frame features form the LSTM input sequence, and the LSTM
    output at the last time step is mapped to class log-probabilities.

    All constructor arguments are optional; ``CnnLstm()`` builds the same
    network as before.

    Args:
        cnn: Per-frame feature extractor. Must return a 2-tuple whose second
            element is the feature tensor. Defaults to a new ``CNN()``.
        input_size: Features per frame after flattening the CNN output; must
            equal the per-frame element count of that output. The default
            180000 ties the model to one CNN / input resolution.
        hidden_size: LSTM hidden size (also the classifier's input size).
        num_layers: Number of stacked LSTM layers.
        n_classes: Number of output classes. Defaults to the module-level
            ``num_classes``.
    """

    def __init__(self, cnn=None, input_size=180000, hidden_size=256,
                 num_layers=2, n_classes=None):
        super().__init__()
        self.cnn = cnn if cnn is not None else CNN()
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True)
        if n_classes is None:
            n_classes = num_classes  # module-level global, as before
        # Size the classifier from the LSTM itself so the two cannot disagree
        # (previously this read an unrelated global ``hidden_size``).
        self.linear = nn.Linear(self.rnn.hidden_size, n_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape ``(batch, time_steps, channels, height, width)``.

        Returns:
            Log-probabilities of shape ``(batch, n_classes)``.

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        print('before forward ')
        print(x.shape)
        if x.dim() != 5:
            raise ValueError(
                "CnnLstm expects a 5-D input (batch, time_steps, channels, "
                f"height, width); got shape {tuple(x.shape)}")
        batch_size, time_steps, channels, height, width = x.size()
        # Fold time into the batch dimension so the CNN sees one frame per row;
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors.
        c_in = x.reshape(batch_size * time_steps, channels, height, width)
        _, c_out = self.cnn(c_in)
        # Unfold back into (batch, time_steps, features) for the LSTM.
        r_in = c_out.reshape(batch_size, time_steps, -1)
        r_out, _ = self.rnn(r_in)
        # Classify from the LSTM output at the last time step only.
        r_out2 = self.linear(r_out[:, -1, :])
        # NOTE: nn.CrossEntropyLoss applies log_softmax itself. Because
        # log_softmax is idempotent this is redundant but not wrong; pair with
        # nn.NLLLoss, or return raw logits, to avoid the duplicate work.
        return f.log_softmax(r_out2, dim=1)


cnnlstm_model = CnnLstm().to(device)
optimizer = torch.optim.Adam(cnnlstm_model.parameters(), lr=learning_rate)
# Alternative: optimizer = torch.optim.SGD(cnnlstm_model.parameters(), lr=learning_rate)
# CnnLstm returns log-probabilities, so nn.NLLLoss() is the exact pairing.
# nn.CrossEntropyLoss re-applies log_softmax, which is a no-op on
# log-probabilities, so it yields the same loss value.
criterion = nn.CrossEntropyLoss()

# Train the model
n_total_steps = len(train_dl)
num_epochs = 20
for epoch in range(num_epochs):
    # Re-enter training mode each epoch in case an evaluation pass called eval().
    cnnlstm_model.train()
    t_losses = []
    for i, (images, labels) in enumerate(train_dl):
        # Loader output is [batch, 3, H, W], e.g. [5, 3, 300, 300]. The final
        # batch is smaller when the dataset size is not a multiple of
        # batch_size (DataLoader's drop_last defaults to False).
        print('load data ' + str(images.shape))
        # Add a time axis of length 1 -> [batch, 1, 3, H, W], the 5-D layout
        # CnnLstm.forward unpacks. Stays in torch: no NumPy round trip/copy.
        images = images.unsqueeze(1)
        print('after expand ')
        print(images.shape)
        # Follow the model's device rather than hard-coding .cuda().
        # torch.autograd.Variable has been a no-op since PyTorch 0.4.
        images = images.to(device, dtype=torch.float32, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = cnnlstm_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Store a Python float: keeping the tensor would hold its autograd
        # history and a device scalar alive for every batch of the epoch.
        t_losses.append(loss.item())

代码中有三处打印语句:

(1)

print('load data '+str(images.shape))

(2)

print('after expand ')
print(images.shape)

(3)

print('before forward ')
print(x.shape)

批量大小(batch size)为 5。共加载 2629 个批次,只有最后一个批次出问题;之前的批次都能正常加载,输出如下:

load data torch.Size([5, 3, 300, 300])
after expand 
(5, 1, 3, 300, 300)
before forward 
torch.Size([5, 1, 3, 300, 300])
load data torch.Size([5, 3, 300, 300])
after expand 
(5, 1, 3, 300, 300)
before forward 
torch.Size([5, 1, 3, 300, 300])
.
.
.
load data torch.Size([5, 3, 300, 300])
after expand 
(5, 1, 3, 300, 300)
before forward 
torch.Size([5, 1, 3, 300, 300])
load data torch.Size([5, 3, 300, 300])
after expand 
(5, 1, 3, 300, 300)
before forward 
torch.Size([5, 1, 3, 300, 300])

加载最后一个批次时,输出为:

load data torch.Size([5, 3, 300, 300])
after expand 
(5, 1, 3, 300, 300)
before forward 
torch.Size([5, 1, 3, 300, 300])
before forward 
torch.Size([15, 300, 300])

为什么“before forward”会被打印两次?而且第二次打印出的形状也不一样。

可能是什么原因?

编辑:

这是加载数据的代码。

inputH = inputW = input_size

# Per-channel (mean, std) handed to tt.Normalize.
# NOTE(review): these are the values commonly quoted for CIFAR-10; confirm
# they are appropriate for the images under ./data.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


def _compose_tfms(*extra_steps):
    """Resize to (inputH, inputW), run *extra_steps*, then ToTensor + Normalize."""
    return tt.Compose(
        [tt.Resize((inputH, inputW), interpolation=2)]
        + list(extra_steps)
        + [tt.ToTensor(), tt.Normalize(*stats)]
    )


# Data transforms (normalisation & data augmentation).
# NOTE(review): train_resize_tfms is not used by the code shown here.
train_resize_tfms = _compose_tfms()
train_tfms = _compose_tfms(tt.RandomHorizontalFlip())  # only augmented split
valid_tfms = _compose_tfms()
test_tfms = _compose_tfms()

# One ImageFolder dataset per split directory.
train_ds = ImageFolder('./data/train', transform=train_tfms)
valid_ds = ImageFolder('./data/valid', transform=valid_tfms)
test_ds = ImageFolder('./data/test', transform=test_tfms)

from torch.utils.data.dataloader import DataLoader

batch_size = 5
# Settings shared by the training and validation loaders.
_shuffled_opts = dict(shuffle=True, num_workers=8, pin_memory=True)

train_dl = DataLoader(train_ds, batch_size=batch_size, **_shuffled_opts)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, **_shuffled_opts)
# Test loader: one image per batch, fixed order, single worker process.
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1,
                     pin_memory=True)

1 个答案:

答案(得分:0)

我对数据加载器进行了一些更改,终于成功了。

    class DataLoader:
        """Build the train / validation / test loaders for ./data/{train,valid,test}.

        Each loader wraps a torchvision ``ImageFolder`` whose images are
        resized to ``(inputH, inputW)``, converted with ``ToTensor`` and
        normalised with ``stats``.

        NOTE(review): this name shadows any ``DataLoader`` imported from
        ``torch.utils.data`` into the same module, which is why the factory
        below uses the fully qualified ``torch.utils.data.DataLoader``.
        """

        # Per-channel (mean, std) passed to tt.Normalize in every pipeline.
        # NOTE(review): values commonly quoted for CIFAR-10; confirm for ./data.
        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        @classmethod
        def _transforms(cls, *augmentations):
            """Return Resize -> *augmentations -> ToTensor -> Normalize(cls.stats)."""
            # NOTE(review): newer torchvision prefers
            # tt.InterpolationMode.BILINEAR over the integer 2 here.
            return tt.Compose([
                tt.Resize((inputH, inputW), interpolation=2),
                *augmentations,
                tt.ToTensor(),
                tt.Normalize(*cls.stats),
            ])

        @staticmethod
        def _make_loader(root, transform, batch_size, shuffle, num_workers):
            """Wrap ``ImageFolder(root, transform)`` in a pinned-memory loader."""
            return torch.utils.data.DataLoader(
                ImageFolder(root, transform),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                pin_memory=True)

        @classmethod
        def get_train_data(cls, batch_size):
            """Return the shuffled training loader (random horizontal flips)."""
            return cls._make_loader(
                './data/train',
                cls._transforms(tt.RandomHorizontalFlip()),
                batch_size=batch_size,
                shuffle=True,
                num_workers=8)

        @classmethod
        def get_validate_data(cls, valid_batch_size):
            """Return the validation loader (no augmentation)."""
            # Shuffling is unnecessary for evaluation; kept to preserve behavior.
            return cls._make_loader(
                './data/valid',
                cls._transforms(),
                batch_size=valid_batch_size,
                shuffle=True,
                num_workers=8)

        @classmethod
        def get_test_data(cls, test_batch_size):
            """Return the test loader in a fixed order with a single worker."""
            return cls._make_loader(
                './data/test',
                cls._transforms(),
                batch_size=test_batch_size,
                shuffle=False,
                num_workers=1)