我正在尝试使用不同的方法(TensorFlow,PyTorch和从头开始)实现2层神经网络,然后根据MNIST数据集比较它们的性能。
我不确定我犯了什么错误,但是PyTorch的准确率只有大约10%,这基本上是随机猜测。我认为权重可能根本没有更新。
请注意,我有意使用TensorFlow提供的数据集,以通过3种不同的方法使我使用的数据保持一致,以进行准确比较。
from tensorflow.examples.tutorials.mnist import input_data
import torch
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = torch.nn.Linear(784, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x):
# x -> (batch_size, 784)
x = torch.relu(x)
# x -> (batch_size, 10)
x = torch.softmax(x, dim=1)
return x
net = Net()
net.zero_grad()
Loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(1000): # loop over the dataset multiple times
batch_xs, batch_ys = mnist_m.train.next_batch(100)
# convert to appropriate settins
# note the input to the linear layer should be (n_sample, n_features)
batch_xs = torch.tensor(batch_xs, requires_grad=True)
# batch_ys -> (batch_size,)
batch_ys = torch.tensor(batch_ys, dtype=torch.int64)
# forward
# output -> (batch_size, 10)
output = net(batch_xs)
# result -> (batch_size,)
result = torch.argmax(output, dim=1)
loss = Loss(output, batch_ys)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
答案 0 :(得分:5)
这里的问题是您没有应用完全连接的层fc1
和fc2
。
您的forward()
当前如下:
def forward(self, x):
# x -> (batch_size, 784)
x = torch.relu(x)
# x -> (batch_size, 10)
x = torch.softmax(x, dim=1)
return x
因此,如果将其更改为:
def forward(self, x):
# x -> (batch_size, 784)
x = self.fc1(x) # added layer fc1
x = torch.relu(x)
# x -> (batch_size, 10)
x = self.fc2(x) # added layer fc2
x = torch.softmax(x, dim=1)
return x
应该可以。
关于Umang Guptas的回答:正如我所看到的,像机器人先生一样,在致电zero_grad()
之前先致电backward()
是很好的。这应该没问题。
编辑:
因此,我做了一个简短的测试-我设置了从1000
到10000
的迭代,以查看整体情况,如果它确实在减少。 (当然,我也将数据加载到了mnist_m
,因为您发布的代码中未包含该数据)
我在代码中添加了打印条件:
if epoch % 1000 == 0:
print('Epoch', epoch, '- Loss:', round(loss.item(), 3))
每1000
次迭代输出损失:
Epoch 0 - Loss: 2.305
Epoch 1000 - Loss: 2.263
Epoch 2000 - Loss: 2.187
Epoch 3000 - Loss: 2.024
Epoch 4000 - Loss: 1.819
Epoch 5000 - Loss: 1.699
Epoch 6000 - Loss: 1.699
Epoch 7000 - Loss: 1.656
Epoch 8000 - Loss: 1.675
Epoch 9000 - Loss: 1.659
使用PyTorch 0.4.1版进行了测试
因此,您可以看到更改了forward()
的网络现在正在学习,剩下的代码我都保持不变。
祝你好运!