PyTorch TypeError: eq() received an invalid combination of arguments

Date: 2019-03-05 09:52:43

Tags: numpy image-processing machine-learning computer-vision pytorch

import numpy as np
import torch

num_samples = 10
def predict(x):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [model(x).data for model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.view(-1,28*28))
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print("accuracy: %d %%" % (100 * correct / total))

Error:

    correct += (predicted == labels).sum().item()
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
* (Tensor other)
  didn't match because some of the arguments have invalid types: (!numpy.ndarray!)
* (Number other)
  didn't match because some of the arguments have invalid types: (!numpy.ndarray!)


1 answer:

Answer 0 (score: 0)

You are trying to compare predicted with labels. However, predicted is a NumPy array (np.ndarray), while labels is a torch.Tensor, so eq() (the == operator) cannot compare the two.
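A minimal sketch of the mismatch, assuming a hypothetical averaged-score tensor in place of the real model outputs (guide and model are not defined in the question). It prints the two operand types and shows that converting either side, torch.from_numpy on the prediction or .numpy() on the labels (which assumes labels live on the CPU), lets the comparison go through.

```python
import numpy as np
import torch

# Hypothetical averaged scores for a batch of 3 inputs and 2 classes.
mean = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])
labels = torch.tensor([1, 0, 0])

predicted = np.argmax(mean.numpy(), axis=1)  # what predict() returns: a NumPy array
print(type(predicted), type(labels))          # numpy.ndarray vs torch.Tensor

# Make both operands the same type before comparing.
as_tensor = torch.from_numpy(predicted)                   # NumPy -> torch.Tensor
correct_torch = (as_tensor == labels).sum().item()
correct_numpy = int((predicted == labels.numpy()).sum())  # torch.Tensor -> NumPy (CPU only)
print(correct_torch, correct_numpy)  # → 2 2
```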
Replace np.argmax with torch.argmax in predict():

    return torch.argmax(mean, dim=1)

You should be fine after that.
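For a self-contained check of the fix, here is a sketch that replaces the sampled models with fixed, made-up outputs. The helper predict_from_outputs and all numbers are hypothetical, not from the original code. Averaging with torch.mean and taking torch.argmax keeps the prediction a torch.Tensor, so the accuracy line from the question works unchanged.

```python
import torch

def predict_from_outputs(yhats):
    # Hypothetical helper: yhats is a list of (batch, num_classes) tensors,
    # standing in for [model(x).data for model in sampled_models].
    mean = torch.mean(torch.stack(yhats), 0)  # average over sampled models
    return torch.argmax(mean, dim=1)          # stays a torch.Tensor (int64)

# Illustrative outputs of two sampled models for a batch of 4 inputs, 3 classes.
yhats = [
    torch.tensor([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.7, 0.2, 0.1]]),
    torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.3, 0.6], [0.6, 0.3, 0.1]]),
]
labels = torch.tensor([0, 1, 2, 1])

predicted = predict_from_outputs(yhats)
correct = (predicted == labels).sum().item()  # tensor == tensor, no TypeError
print(predicted.tolist(), correct)  # → [0, 1, 2, 0] 3
```

Keeping the whole pipeline in torch also avoids the .numpy() call, which would fail if the outputs were on a GPU.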