如何在不使用SciPy的情况下从截断的高斯分布中采样?

时间:2019-02-09 18:34:16

标签: python scipy gaussian

SciPy是一个巨大的图书馆。令人尴尬的是,为了使用简单的功能(即计算截断的分布),我必须安装(并导入)23 MB的代码。

有一些解决方案可以更简单地解决此问题吗?

1 个答案:

答案 0 :(得分:0)

您可以通过inverse transform sampling手动实现它。您基本上可以根据从0到1之间的均匀分布得出的值来计算累积分布函数的逆。

import numpy as np

def normal(x, mu, sig):
    return 1. / (np.sqrt(2 * np.pi) * sig) * np.exp(-0.5 * np.square(x - mu) / np.square(sig))


def trunc_normal(x, mu, sig, bounds=None):
    if bounds is None: 
        bounds = (-np.inf, np.inf)

    norm = normal(x, mu, sig)
    norm[x < bounds[0]] = 0
    norm[x > bounds[1]] = 0

    return norm


def sample_trunc(n, mu, sig, bounds=None):
    """ Sample `n` points from truncated normal distribution """
    x = np.linspace(mu - 5. * sig, mu + 5. * sig, 10000)
    y = trunc_normal(x, mu, sig, bounds)
    y_cum = np.cumsum(y) / y.sum()

    yrand = np.random.rand(n)
    sample = np.interp(yrand, y_cum, x)

    return sample


# Example
import matplotlib.pyplot as plt
samples = sample_trunc(10000, 0, 1, (-1, 1))
plt.hist(samples, bins=100)
plt.show()