如何在Pytorch中进行批量学习?

时间:2019-06-19 01:30:33

标签: python neural-network pytorch

当您查看pytorch代码内部网络架构的构建方式时,我们需要扩展torch.nn.Module__init__内部,我们定义网络模块,而pytorch将跟踪以下内容的梯度这些模块的参数。然后,在forward函数中,我们定义如何对网络进行前向传递。

我在这里不明白的是批处理学习将如何发生。在上面的定义(包括forward函数)中,我们都不关心网络输入的批量大小。要执行批处理学习,我们唯一需要设置的就是在输入中添加一个与批处理大小相对应的额外维度,但是如果我们进行批处理学习,则网络定义中的任何内容都不会更改。至少,这是我在代码here中看到的东西。

因此,如果到目前为止我所解释的所有事情都是正确的(如果您对我有误解,请让我知道,我将不胜感激),如果在定义中未声明任何关于批量大小的信息,将如何执行批量学习我们的网络类(继承torch.nn.Module的类)?具体来说,我很想知道当我们仅将nn.MSELoss设置为批处理尺寸时,如何在pytorch中实现批处理梯度下降算法。

1 个答案:

答案 0 :(得分:1)

检查:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()         

    def forward(self, x):
        print("Hi ma")        
        print(x)
        x = F.relu(x)
        return x

n = Net()
r = n(torch.tensor(-1))
print(r)
r = n.forward(torch.tensor(1)) #not planned to call directly
print(r)

退出:

Hi ma
tensor(-1)
tensor(0)
Hi ma
tensor(1)
tensor(1)

要记住的是,forward不应直接调用。 PyTorch使此模块对象n可调用。他们实现了 callable ,例如:

 def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            raise RuntimeError(
                "forward hooks should never return any values, but '{}'"
                "didn't return None".format(hook))
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

只有n()会自动呼叫forward

通常,__init__定义模块结构,forward()定义单个批次的操作。

如果某些结构元素需要重复该操作,或者像我们x = F.relu(x)一样直接在张量上调用函数。

您获得了如此出色的表现,因为PyTorch已针对这种方式进行了优化,因此PyTorch中的所有功能都将分批处理(迷你批处理)。

这意味着当您读取图像时,您不会读取单个图像,而是读取bs批次的图像。