PyTorch设计:为什么torch.distributions.multivariate_normal在其类之外具有方法?

时间:2019-04-30 18:55:31

标签: python pytorch

我试图更好地理解pytorch的设计。我试图从多元法线中提取样本,然后找到了torch.distributions.multivariate_normal,令我惊讶的是,该模块在其MultivariateNormal()类之外定义了许多受保护的函数。

我对为什么会如此感到困惑。为什么不将所有这些函数都定义为MultivariateNormal()类中的类方法?这样,我们可以通过

实例化此类的对象。
torch.distributions.multivariate_normal(mu,sigma)

而不是

torch.distributions.multivariate_normal.MultivariateNormal(mu,sigma).

有什么想法吗?

谢谢。

1 个答案:

答案 0 :(得分:1)

您可以直接调用MultivariateNormal:

import torch
gaussian = torch.distributions.MultivariateNormal(torch.ones(2),torch.eye(2))

但是类 MultivariateNormal 已在文件“ torch / distributions / multivariate_normal.py”中实现,因此两个调用都是正确的