可变错误O'Reilly编程PyTorch

时间:2019-10-21 17:06:49

标签: deep-learning pytorch

我正在阅读O'Reilly在2019年9月出版的《 Programming Pytorch ..》中的描述一种用于图像分类的简单线性神经网络。

targettargets相比,开放模型中的变量名称存在错误(无后顾之忧),但是似乎省略了变量声明{{1 }}(还有train_iterator,未显示)。

我想知道他们(我认为)打算使用dev_iterator变量是什么?

p27

train_iterator

所以,..

def train(model, optimiser, loss_fn, train_loader, val_loader, epochs=20, device='cpu'):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, target = batch # Bug here for 'target'
            inputs = inputs.to(device)
            target = targets.to(device)
            output = model(inputs)
            loss = loss_fin(output, target)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item()
        training_loss /= len(train_iterator) # What is train_iterator?

必须

    inputs, target = batch

在培训步骤(未显示)下面的验证步骤中

    inputs, targets = batch

没什么大不了的,代码只是分配给Cudas(GPU)或CPU。

变量 inputs, targets = batch ... targets = targets.to(device) 定义训练损失(重要诊断)。我假设在纪元和批处理迭代器之间应该声明一个迭代器,还是在训练循环中?

注释 train_iterator仅指Pytorch Dataloader。该模型涉及具有ReLU激活功能的3个线性层。

0 个答案:

没有答案