from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from tensorflow.examples.tutorials.mnist import input_data
import torch.optim as optim
import tensorflow.python.util.deprecation as deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
import matplotlib.pyplot as plt
%matplotlib inline
from plot import plot_loss_and_acc
# Load the MNIST data sets from the "MNIST_data" directory. one_hot=False is
# passed; the labels are later converted to long tensors and fed to
# nn.CrossEntropyLoss.
mnist = input_data.read_data_sets("MNIST_data", one_hot=False)
# Number of examples requested from the data set per call to next_batch().
batch_size = 250
# Number of passes of the outer training loop.
epoch_num = 10
# Learning rate handed to the Adam optimizer.
lr = 0.0001
# Progress is reported every `disp_freq` batches.
disp_freq = 20
def next_batch(train=True):
    """Fetch the next MNIST mini-batch as torch Variables.

    Args:
        train: when True the batch is drawn from ``mnist.train``,
            otherwise from ``mnist.test``.

    Returns:
        A ``(batch_img, batch_label)`` pair: the images as a float tensor
        and the labels as a long tensor, each wrapped in ``Variable``.
        ``batch_size`` examples are requested from the data set.
    """
    # Select the split once instead of duplicating the fetch call.
    split = mnist.train if train else mnist.test
    images_np, labels_np = split.next_batch(batch_size)

    # numpy array -> torch tensor of the required dtype -> Variable
    batch_img = Variable(torch.from_numpy(images_np).float())
    batch_label = Variable(torch.from_numpy(labels_np).long())
    return batch_img, batch_label
class MLP(nn.Module):
    """Multilayer perceptron with two 128-unit hidden layers.

    Architecture:
        x -> FC -> relu -> dropout -> FC -> relu -> dropout -> FC -> output
    """

    def __init__(self, n_features, n_classes):
        """Create the three fully connected layers.

        Args:
            n_features: size of each input vector.
            n_classes: number of scores produced per input.
        """
        super(MLP, self).__init__()
        self.layer1 = nn.Linear(n_features, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_classes)

    def forward(self, x, training=True):
        """Return the raw (pre-softmax) class scores for ``x``.

        Args:
            x: batch of inputs whose last dimension is ``n_features``.
            training: when True, dropout with p=0.5 is applied after each
                hidden layer. Dropout is driven by this argument, not by
                the module's ``train()``/``eval()`` state.
        """
        x = F.relu(self.layer1(x))
        x = F.dropout(x, 0.5, training=training)
        x = F.relu(self.layer2(x))
        x = F.dropout(x, 0.5, training=training)
        x = self.layer3(x)
        return x

    def predict(self, x):
        """Return class probabilities for a batch of inputs (dropout off)."""
        # Normalise over the class scores (last dimension). Omitting `dim`
        # relies on a deprecated implicit choice and emits a warning.
        return F.softmax(self.forward(x, training=False), dim=-1)

    def accuracy(self, x, y):
        """Return the percentage of inputs whose predicted label equals ``y``.

        Args:
            x: a batch of inputs.
            y: the true labels associated with ``x``.

        Returns:
            A 0-dim tensor holding the accuracy in percent, detached from
            the autograd graph. Nothing is printed; callers decide how to
            report the value.
        """
        # Evaluation only: no gradient bookkeeping is needed here.
        with torch.no_grad():
            prediction = self.predict(x)
            _, indices = torch.max(prediction, 1)
            correct = torch.eq(indices.float(), y.float()).float()
            acc = 100 * torch.sum(correct) / y.size()[0]
        return acc
# Build the model, loss and optimizer, train for `epoch_num` epochs while
# recording loss and training accuracy every `disp_freq` batches, then
# evaluate on one batch of test data.

# define the neural network (multilayer perceptron): 784 inputs, 10 classes
net = MLP(784, 10)
# calculate the number of batches per epoch
batch_per_ep = mnist.train.num_examples // batch_size
# define the loss (criterion) and create an optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
print(' ')
print("__________Training__________________")
# Histories consumed by the plots: one entry per logging step, all three
# lists are appended together so they always have the same length.
xArray = []  # training progress in (fractional) epochs
yLoss = []   # training loss of the logged batch
yAcc = []    # training accuracy (percent) of the logged batch
for ep in range(epoch_num):  # epochs loop
    for batch_n in range(batch_per_ep):  # batches loop
        features, labels = next_batch()
        # Reset gradients
        optimizer.zero_grad()
        # Forward pass
        output = net(features)
        loss = criterion(output, labels)
        # Backward pass and updates
        loss.backward()  # calculate the gradients (backpropagation)
        optimizer.step()  # update the weights
        if batch_n % disp_freq == 0:
            # Accuracy on the batch that was just used for the update,
            # measured with dropout disabled (see MLP.predict).
            train_acc = net.accuracy(features, labels)
            print('epoch: {} - batch: {}/{} '.format(ep, batch_n, batch_per_ep))
            # Fractional epoch gives every logged point its own x value;
            # float() keeps the division exact on Python 2 as well.
            xArray.append(ep + float(batch_n) / batch_per_ep)
            # Store plain Python numbers so matplotlib can plot them directly.
            yLoss.append(loss.item())
            yAcc.append(train_acc.item())
            print('loss: ', loss.data)
            print('accuracy: ', train_acc)
            print('__________________________________')
# test the accuracy on a batch of test data
features, labels = next_batch(train=False)
# Evaluate once and reuse the value for both the report and `accuracy`.
accuracy = net.accuracy(features, labels)
print("Result")
print('Test accuracy: ', accuracy)
print('loss: ', loss.data)
# Draw the loss curve and then the accuracy curve, each in its own figure.
# Both share the x values in xArray and the 'epoch' x-axis label; only the
# plotted series, the y-axis label and the title differ.
for series, y_label, title in (
    (yLoss, 'loss', 'Loss Plot'),
    (yAcc, ' accuracy', 'Accuracy Plot '),
):
    # plot the points
    plt.plot(xArray, series)
    # name the axes and give the graph a title
    plt.xlabel('epoch')
    plt.ylabel(y_label)
    plt.title(title)
    # show the finished plot before starting the next one
    plt.show()
我想显示模型在训练数据集上的准确率。我已经能够显示并绘制损失，但准确率还没能做到。我知道我只差一两行代码，但不知道该怎么写。
我的意思是，如果我能像损失一样，在每个 epoch 旁边显示准确率，那么我就可以自己把它绘制出来。
答案 0 :(得分:1)
您好,将代码print('epoch: {} - batch: {}/{} '.format(ep, batch_n, batch_per_ep))
替换为
print('epoch: {} - batch: {}/{} - accuracy: {}'.format(ep, batch_n, batch_per_ep, net.accuracy(features,labels)))
希望这会有所帮助。