Expected input batch_size (1) to match target batch_size (10)

Date: 2020-08-28 09:18:50

Tags: python pytorch mnist

I am trying to train on the MNIST dataset in PyTorch, using cross-entropy loss and a simple fully connected neural network (784, 512, 128, 10). However, I get this error:

ValueError: Expected input batch_size (1) to match target batch_size (10).

My model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Netz(nn.Module):
  def __init__(self,n_input_features):
    super(Netz,self).__init__()
    self.linear=nn.Linear(784,512,bias=True)
    self.l1=nn.Linear(512,128,bias=True)
    self.l2=nn.Linear(128,10,bias=True)
  def forward(self,x):
    x=F.relu(self.linear(x))
    x=F.relu(self.l1(x))
    return F.softmax(self.l2(x),dim=1)
model=Netz(784)
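The traceback further down shows that `nn.CrossEntropyLoss` already applies `log_softmax` to its input (`nll_loss(log_softmax(input, 1), ...)`), so the usual pattern is for `forward` to return raw logits rather than softmax probabilities. A minimal sketch of that variant, assuming the same layer sizes; the class name `NetzLogits` is hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NetzLogits(nn.Module):
    """Same 784-512-128-10 layers as Netz, but forward() returns raw logits."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 512)
        self.l1 = nn.Linear(512, 128)
        self.l2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.linear(x))
        x = F.relu(self.l1(x))
        # No softmax here: nn.CrossEntropyLoss applies log_softmax internally.
        return self.l2(x)


model = NetzLogits()
x = torch.rand(1, 784)             # one flattened image, i.e. batch_size=1
logits = model(x)
print(logits.shape)                # torch.Size([1, 10])

# Apply softmax separately, only when probabilities are actually needed (e.g. at inference).
probs = F.softmax(logits, dim=1)
```

The first dimension of the output (1 here) is the value reported as "input batch_size" in the error message.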

Data preprocessing:

import numpy as np
from tensorflow import keras  # assumption: Keras as bundled with TensorFlow

mnist = keras.datasets.mnist
#Copying data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
#One-hot encoding the labels
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
#Flattening the images
x_train_reshaped = x_train.reshape((60000,784))
x_test_reshaped = x_test.reshape((10000,784))
#Normalizing the inputs
x_train_nml = x_train_reshaped/255.0 
x_test_nml = x_test_reshaped/255.0
comb_data = np.hstack([x_train_nml,y_train])
#np.random.shuffle(comb_data)
x_train = comb_data[:,:-10]
y_train = comb_data[:,-10:]
x_train=torch.from_numpy(x_train.astype(np.float32))
x_test=torch.from_numpy(x_test_nml.astype(np.float32))  # use the flattened, normalized test images
y_train=torch.from_numpy(y_train.astype(np.float32))
y_test=torch.from_numpy(y_test.astype(np.float32))
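`keras.utils.to_categorical` turns each label into a one-hot row of length 10, so a batch of targets has shape `(batch_size, 10)`. `nn.CrossEntropyLoss` is commonly used with a 1-D tensor of class indices (dtype `long`) as the target, one index per sample; the `nll_loss` check in the traceback compares `input.size(0)` with `target.size(0)` on that basis. A sketch of recovering indices from one-hot rows, assuming a small hand-made label array in place of MNIST:

```python
import numpy as np
import torch

labels = np.array([5, 0, 4])                    # example class labels (stand-in for y_train)
one_hot = np.eye(10, dtype=np.float32)[labels]  # same layout as keras.utils.to_categorical
print(one_hot.shape)                            # (3, 10)

y = torch.from_numpy(one_hot)
y_idx = y.argmax(dim=1)                         # position of the 1 in each row
print(y_idx, y_idx.dtype)                       # tensor([5, 0, 4]) torch.int64
```

Alternatively, skipping `to_categorical` and converting the original integer labels with `torch.from_numpy(y_train).long()` yields the same kind of target directly.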

Data loading:

from torch.utils.data import Dataset, DataLoader

class Data(Dataset):
    def __init__(self):
        self.x=x_train
        self.y=y_train
        self.len=self.x.shape[0]
    def __getitem__(self,index):
        return self.x[index],self.y[index]
    def __len__(self):
        return self.len
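With `batch_size=1`, the `DataLoader` stacks one sample per batch, so `x` arrives with shape `(1, 784)` and the one-hot `y` with shape `(1, 10)`. A sketch with random stand-in tensors; the names `x_fake` and `y_fake` and the `TensorDataset` shortcut are assumptions, not part of the original code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

x_fake = torch.rand(4, 784)              # 4 fake flattened images
y_fake = torch.eye(10)[[3, 1, 4, 1]]     # 4 one-hot labels, shape (4, 10)

loader = DataLoader(TensorDataset(x_fake, y_fake), batch_size=1, shuffle=False)
x, y = next(iter(loader))                # first batch
print(x.shape, y.shape)                  # torch.Size([1, 784]) torch.Size([1, 10])
```

Because `x` is already 2-D, the `x.view(x.shape[0], -1)` call in the training loop leaves its shape unchanged.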

Main training loop:

criterion=nn.CrossEntropyLoss()
print(criterion)
optimizer=torch.optim.SGD(model.parameters(),lr=0.05)
dataset=Data()
train_data=DataLoader(dataset=dataset,batch_size=1,shuffle=False)
num_epochs=5
for epoch in range(num_epochs):
  for x,y in train_data:
    x = x.view(x.shape[0], -1)
    y_pred=model(x)
    #y=y.squeeze_()
    print(y_pred)
    loss=criterion(y_pred,y.flatten())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Full error:

ValueError                                Traceback (most recent call last)
<ipython-input-85-142a974fa5cd> in <module>()
     10     #y=y.squeeze_()
     11     print(y_pred)
---> 12     loss=criterion(y_pred,y.flatten())
     13     loss.backward()
     14     optimizer.step()

3 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    946     def forward(self, input: Tensor, target: Tensor) -> Tensor:
    947         return F.cross_entropy(input, target, weight=self.weight,
--> 948                                ignore_index=self.ignore_index, reduction=self.reduction)
    949 
    950 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2420     if size_average is not None or reduce is not None:
   2421         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2422     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2423 
   2424 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2214     if input.size(0) != target.size(0):
   2215         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2216                          .format(input.size(0), target.size(0)))
   2217     if dim == 2:
   2218         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (1) to match target batch_size (10).
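The failing check is `input.size(0) != target.size(0)` in `nll_loss`. Plugging in the shapes from the loop makes the mismatch visible: the model output has batch size 1, while `y.flatten()` turns the `(1, 10)` one-hot target into a 1-D tensor of length 10, which is read as 10 separate targets. A sketch using stand-in tensors with the same shapes (the values are arbitrary):

```python
import torch

y_pred = torch.softmax(torch.rand(1, 10), dim=1)  # stand-in for model(x): shape (1, 10)
y = torch.eye(10)[[7]]                            # stand-in one-hot target: shape (1, 10)
target = y.flatten()                              # shape (10,)

# These are the two numbers compared in nll_loss.
print(y_pred.size(0), target.size(0))             # 1 10
```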

Can someone tell me how to fix this error?

1 Answer:

Answer 0 (score: 0)

I believe your line

loss=criterion(y_pred,y.flatten())

should be changed to

loss=criterion(y_pred.squeeze(0),y.flatten())

so that the sizes of the first dimension match.
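`squeeze(0)` in the answer changes the first dimension of `y_pred`. Another common way to satisfy the same check, sketched here as an alternative rather than as the answerer's suggestion, is to keep the batch dimension and convert the one-hot target to class indices with `argmax`, so that both tensors have size 1 in dimension 0. The stand-in logits and target below are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()

logits = torch.tensor([[0.5, 2.0, -1.0, 0.0, 0.3, 0.1, -0.2, 1.2, 0.0, 0.4]])  # shape (1, 10)
y = torch.eye(10)[[1]]            # one-hot target for class 1, shape (1, 10)

target = y.argmax(dim=1)          # tensor([1]), dtype int64, shape (1,)
loss = criterion(logits, target)  # batch sizes now match: 1 and 1

# The same value computed by hand: negative log-softmax at the true class.
manual = -F.log_softmax(logits, dim=1)[0, 1]
print(loss.item(), manual.item())
```

In the question's loop this corresponds to `loss=criterion(y_pred,y.argmax(dim=1))`, ideally with `y_pred` being raw logits, since the loss applies `log_softmax` itself.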