我正在使用jax(https://github.com/google/jax)来编码神经网络,并模拟我的输入,我想根据泊松分布生成样本数组。鉴于jax的限制,我该怎么办?
我已经尝试使用np.random.poisson(mu,N)和scipy.stats.poisson.rvs(mu,size = N)。这些都不起作用,因为jax.numpy和jax.scipy.stats不支持它们。因此,基本上,我需要一个替代解决方案,或者使用jax支持的另一个程序包,或者对poisson函数进行硬编码。
from jax import jit, vmap
import jax.numpy as np
import numpy as onp
from scipy.stats import poisson
def build_input_and_targets_simulated(ntime, key):
"""
Function: Simulate inputs and targets.
Args:
ntime: number of time steps in input
key: key for random number generator
Returns:
inputs: txu matrix of inputs
targets: txu matrix of target classifications
"""
mu = 0.03 # average number of events per interval
# scipy method
inputs = np.array([poisson.rvs(mu, size=ntime), poisson.rvs(mu,
size=ntime)]).T
# numpy method
inputs = np.array([onp.random.poisson(mu, ntime),
onp.random.poisson(mu, ntime)]).T
targets = onp.zeros((ntime,1))
# determine target based on difference in inputs
diffT = np.cumsum(inputs[:,0]) - np.cumsum(inputs[:,1]) # calculate
cumulative difference in inputs for each time point
targets[diffT > 0] = 1 # if diffT > 0, set binary choice
targets[diffT ==0] = int(random.randint(key, (1,1), 0, 2)) # if
inputs are equal, select target class randomly
return inputs, targets
# Now batch it and jit.
build_input_and_target = build_input_and_targets_simulated
build_inputs_and_targets = vmap(build_input_and_target, in_axes=(None,
0))
build_inputs_and_targets_jit = jit(build_inputs_and_targets,
static_argnums=(0,))
seed = onp.random.randint(0, 1000000)
key = random.PRNGKey(seed)
ntimesteps = 25
inputs, targets = build_inputs_and_targets_jit(ntimesteps, key)
如果我使用scipy方法,则会收到如下错误:
Exception: Tracer can't be used with raw numpy functions. You might have
import numpy as np
instead of
import jax.numpy as np
如果我使用numpy方法,则会收到如下错误:
TypeError: 'BatchTracer' object cannot be interpreted as an integer
这两个错误似乎都与jax使用的专用数据类型有关(使用jit时是必需的)。
我该如何解决?