我是一名Python初学者,正在研究有关神经网络和PY火炬模块的教程。我不太了解这条线的行为。
import torch.nn as nn
loss = nn.MSELoss()
print(loss)
>>MSELoss()
由于nn.MSELoss是一个类,为什么不将其称为变量loss而不将其实例化为类对象? MSELoss类中的哪种类型的代码可以实现此行为?
答案 0 :(得分:0)
它确实实例化了一个类。但是,该类实现了特殊的__call__
方法,该方法使您可以像使用函数一样在其上使用调用运算符()
。它还实现了__repr__
方法,该方法可以自定义打印时的外观。
答案 1 :(得分:0)
根据Documentation,nn.MSELoss()
创建了一个衡量均方误差的标准,您可以使用以下方式:
loss = nn.MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)
output.backward()
您可以检查loss
是MSELoss
类:
print(type(loss).__name__)
>>> MSELoss
答案 2 :(得分:0)
打印某些对象时,实际上是在Python中调用其__str__
方法,或者如果未定义该对象,则调用__repr__
(来自 representation )。
在您的情况下,它是关于 normal 类的,但是它是__repr__
has been overriden:
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str