我的神经网络花太多时间训练一个纪元

时间:2019-07-17 21:10:03

标签: computer-vision conv-neural-network pytorch

我正在训练一个试图对交通标志进行分类的神经网络,但只训练一个时间段就花费了太多时间,一个时间段可能要花费30分钟以上,我已将批次大小设置为64,学习率设为0.002,则输入是3个通道的20x20像素,模型摘要显示它正在训练173,931个参数,是太多还是很好?

这是网络体系结构

import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Network(nn.Module):
  def __init__(self):
    super(Network,self).__init__()
    #Convolutional Layers
    self.conv1 = nn.Conv2d(3,16,3,padding=1)
    self.conv2 = nn.Conv2d(16,32,3,padding=1)

    #Max Pooling Layers
    self.pool = nn.MaxPool2d(2,2)

    #Linear Fully connected layers
    self.fc1 = nn.Linear(32*5*5,200)
    self.fc2 = nn.Linear(200,43)

    #Dropout
    self.dropout = nn.Dropout(p=0.25)


  def forward(self,x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))

    x = x.view(-1,32*5*5)
    x = self.dropout(x)
    x = F.relu(self.fc1(x))
    x = self.dropout(x)
    x = self.fc2(x)

    return x

这是优化程序实例

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optim = optim.SGD(model.parameters(),lr = 0.002)

这是培训代码

epochs = 20
valid_loss_min = np.Inf
print("Training the network")

for epoch in range (1,epochs+1):
  train_loss = 0
  valid_loss = 0
  model.train()

  for data,target in train_data:
    if gpu_available:
      data,target = data.cuda(),target.cuda()

    optim.zero_grad()
    output = model(data)
    loss = criterion(output,target)
    loss.backward()
    optim.step()
    train_loss += loss.item()*data.size(0)

  #########################
  ###### Validate #########
  model.eval()
  for data,target in valid_data:
    if gpu_available:
      data,target = data.cuda(),target.cuda()

    output = model(data)
    loss = criterion(output,target)
    valid_loss += loss.item()*data.size(0)

  train_loss = train_loss/len(train_data.dataset)
  valid_loss = train/len(valid_data.dataset)

  print("Epoch {}.....Train Loss = {:.6f}....Valid Loss = {:.6f}".format(epoch,train_loss,valid_loss))

  if valid_loss < valid_loss_min:
    torch.save(model.state_dict(), 'model_traffic.pt')
    print("Valid Loss min {:.6f} >>> {:.6f}".format(valid_loss_min, valid_loss))

我正在通过Google colab使用GPU

0 个答案:

没有答案