为什么我们将nn.Module作为参数传递给神经网络的类定义?

时间:2019-06-01 09:56:32

标签: machine-learning module pytorch

我想了解为什么当我们为GAN的神经网络定义类时,为什么将torch.nn.Module作为参数传递?

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

2 个答案:

答案 0 :(得分:2)

此行

class Generator(nn.Module):

简单意味着Generator类将继承nn.Module类,它不是参数。

但是,笨拙的 init 方法:

def __init__(self, input_size, hidden_size, output_size, f):

具有自我,这就是为什么您可能会将此视为论点。

这是Python类实例self。它应该留下还是应该去进行一些艰苦的战斗,但是Guido在他的博客why it has to stay中解释说。

答案 1 :(得分:0)

我们本质上是使用nn.Module(及其功能)定义类'Generator'。在编程中,我们将此称为继承(带有super(Generator, self).__init__())。

您可以在此处了解更多信息:https://www.w3schools.com/python/python_inheritance.asp