我试图更好地理解pytorch的设计。我试图从多元法线中提取样本,然后找到了torch.distributions.multivariate_normal,令我惊讶的是,该模块在其MultivariateNormal()类之外定义了许多受保护的函数。
我对为什么会如此感到困惑。为什么不将所有这些函数都定义为MultivariateNormal()类中的类方法?这样,我们可以通过
实例化此类的对象。torch.distributions.multivariate_normal(mu,sigma)
而不是
torch.distributions.multivariate_normal.MultivariateNormal(mu,sigma).
有什么想法吗?
谢谢。
答案 0 :(得分:1)
您可以直接调用MultivariateNormal:
import torch
gaussian = torch.distributions.MultivariateNormal(torch.ones(2),torch.eye(2))
但是类 MultivariateNormal 已在文件“ torch / distributions / multivariate_normal.py”中实现,因此两个调用都是正确的