如何使用jax(jit)从python中的泊松分布生成样本数组?

时间:2019-06-26 21:28:13

标签: python jit poisson

我正在使用jax(https://github.com/google/jax)来编码神经网络,并模拟我的输入,我想根据泊松分布生成样本数组。鉴于jax的限制,我该怎么办?

我已经尝试使用np.random.poisson(mu,N)和scipy.stats.poisson.rvs(mu,size = N)。这些都不起作用,因为jax.numpy和jax.scipy.stats不支持它们。因此,基本上,我需要一个替代解决方案,或者使用jax支持的另一个程序包,或者对poisson函数进行硬编码。

from jax import jit, vmap
import jax.numpy as np
import numpy as onp
from scipy.stats import poisson

def build_input_and_targets_simulated(ntime, key):
    """
    Function: Simulate inputs and targets.
    Args:
        ntime: number of time steps in input
        key: key for random number generator
    Returns:
        inputs: txu matrix of inputs
        targets: txu matrix of target classifications
    """
    mu = 0.03  # average number of events per interval

    # scipy method
    inputs = np.array([poisson.rvs(mu, size=ntime), poisson.rvs(mu, 
size=ntime)]).T

    # numpy method
    inputs = np.array([onp.random.poisson(mu, ntime), 
onp.random.poisson(mu, ntime)]).T

    targets = onp.zeros((ntime,1))

    # determine target based on difference in inputs
    diffT = np.cumsum(inputs[:,0]) - np.cumsum(inputs[:,1]) # calculate 
cumulative difference in inputs for each time point

    targets[diffT > 0] = 1 # if diffT > 0, set binary choice
    targets[diffT ==0] = int(random.randint(key, (1,1), 0, 2)) # if 
inputs are equal, select target class randomly

    return inputs, targets

# Now batch it and jit.
build_input_and_target = build_input_and_targets_simulated
build_inputs_and_targets = vmap(build_input_and_target, in_axes=(None, 
0))
build_inputs_and_targets_jit = jit(build_inputs_and_targets, 
static_argnums=(0,))


seed = onp.random.randint(0, 1000000)
key = random.PRNGKey(seed)
ntimesteps = 25

inputs, targets = build_inputs_and_targets_jit(ntimesteps, key)

如果我使用scipy方法,则会收到如下错误:

Exception: Tracer can't be used with raw numpy functions. You might have
  import numpy as np
instead of
  import jax.numpy as np

如果我使用numpy方法,则会收到如下错误:

TypeError: 'BatchTracer' object cannot be interpreted as an integer

这两个错误似乎都与jax使用的专用数据类型有关(使用jit时是必需的)。

我该如何解决?

0 个答案:

没有答案