我已经完成了MNIST数据集的PyTorch MLP模型,但得到了两个不同的结果:使用PyTorch的MNIST数据集时精度为0.90+,但使用Keras的MNIST数据集时精度为~0.10。 下面是我的依赖代码:PyTorch 0.3.0.post4,keras 2.1.3,tensorflow后端1.4.1 gpu版本。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch as pt
import torchvision as ptv
from keras.datasets import mnist
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
# training data from PyTorch
train_set = ptv.datasets.MNIST("./data/mnist/train", train=True, transform=ptv.transforms.ToTensor(), download=True)
test_set = ptv.datasets.MNIST("./data/mnist/test", train=False, transform=ptv.transforms.ToTensor(), download=True)
train_dataset = DataLoader(train_set, batch_size=100, shuffle=True)
test_dataset = DataLoader(test_set, batch_size=10000, shuffle=True)
class MLP(pt.nn.Module):
"""The Multi-layer perceptron"""
def __init__(self):
super(MLP, self).__init__()
self.fc1 = pt.nn.Linear(784, 512)
self.fc2 = pt.nn.Linear(512, 128)
self.fc3 = pt.nn.Linear(128, 10)
self.use_gpu = True
def forward(self, din):
din = din.view(-1, 28 * 28)
dout = F.relu(self.fc1(din))
dout = F.relu(self.fc2(dout))
# return F.softmax(self.fc3(dout))
return self.fc3(dout)
model = MLP().cuda()
print(model)
# loss func and optim
optimizer = pt.optim.SGD(model.parameters(), lr=1)
criterion = pt.nn.CrossEntropyLoss().cuda()
def evaluate_acc(pred, label):
pred = pred.cpu().data.numpy()
label = label.cpu().data.numpy()
test_np = (np.argmax(pred, 1) == label)
test_np = np.float32(test_np)
return np.mean(test_np)
def evaluate_loader(loader):
print("evaluating ...")
accurarcy_list = []
for i, (inputs, labels) in enumerate(loader):
inputs = pt.autograd.Variable(inputs).cuda()
labels = pt.autograd.Variable(labels).cuda()
outputs = model(inputs)
accurarcy_list.append(evaluate_acc(outputs, labels))
print(sum(accurarcy_list) / len(accurarcy_list))
def training(d, epochs):
for x in range(epochs):
for i, data in enumerate(d):
optimizer.zero_grad()
(inputs, labels) = data
inputs = pt.autograd.Variable(inputs).cuda()
labels = pt.autograd.Variable(labels).cuda()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if i % 200 == 0:
print(i, ":", evaluate_acc(outputs, labels))
# Training MLP for 4 epochs with MNIST dataset from PyTorch
training(train_dataset, 4)
# The accuracy is ~0.96.
evaluate_loader(test_dataset)
print("###########################################################")
def load_mnist():
(x, y), (x_test, y_test) = mnist.load_data()
x = x.reshape((-1, 1, 28, 28)).astype(np.float32)
x_test = x_test.reshape((-1, 1, 28, 28)).astype(np.float32)
y = y.astype(np.int64)
y_test = y_test.astype(np.int64)
print("x.shape", x.shape, "y.shape", y.shape,
"\nx_test.shape", x_test.shape, "y_test.shape", y_test.shape,
)
return x, y, x_test, y_test
class TMPDataset(Dataset):
"""Dateset for loading Keras MNIST dataset."""
def __init__(self, a, b):
self.x = a
self.y = b
def __getitem__(self, item):
return self.x[item], self.y[item]
def __len__(self):
return len(self.y)
x_train, y_train, x_test, y_test = load_mnist()
# Create dataloader for MNIST dataset from Keras.
test_loader = DataLoader(TMPDataset(x_test, y_test), num_workers=1, batch_size=10000)
train_loader = DataLoader(TMPDataset(x_train, y_train), shuffle=True, batch_size=100)
# Evaluate the performance of MLP trained on PyTorch dataset and the accurach is ~0.96.
evaluate_loader(test_loader)
evaluate_loader(train_loader)
model = MLP().cuda()
print(model)
optimizer = pt.optim.SGD(model.parameters(), lr=1)
criterion = pt.nn.CrossEntropyLoss().cuda()
# Train now on MNIST dataset from Keras.
training(train_loader, 4)
# Evaluate the trianed model on MNIST dataset from Keras and result in performance ~0.10...
evaluate_loader(test_loader)
evaluate_loader(train_loader)
我检查了Keras MNIST数据集中的一些样本,发现没有错误。 我想知道数据集有什么问题? 代码可以无错误地运行,运行它以查看结果。
答案 0 :(得分:3)
来自Keras的MNIST数据不标准化;在Keras MNIST MLP example之后,您应该手动执行此操作,即您应在load_data()
函数中包含以下内容:
x /= 255
x_test /= 255
不确定PyTorch,但似乎来自他们自己的实用程序函数的MNIST数据已经标准化(就像Tensorflow的情况一样 - 请参阅我的回答here中的第三点)。
在非标准化输入数据的情况下,10%的准确度(即相当于随机猜测)是完全一致的。