我有一个矩阵,其中每一行都是一个分布。我想从theano中的这个矩阵的每一行中采样一个索引。
我可以从单个向量中进行采样,请参阅下面代码中的函数f
。
但是我无法将其扩展到矩阵,f_mat
不起作用。
import theano
import theano.tensor as tt
import theano.tensor.shared_randomstreams
import numpy as np
n = 3
seed = 1
rng = tt.shared_randomstreams.RandomStreams(seed=seed)
dist1 = tt.fvector('dist')
i = rng.choice(size=(1,), a=dist1.shape[0], p=dist1)
f = theano.function([dist1], [i])
_dist = np.asarray([0.5, 0.5, 0.])
print f(_dist.astype(np.float32)) #THIS IS WORKING
dist_mat = tt.fmatrix('dist_mat')
i_mat = rng.choice(size=(dist_mat.shape[0],), a=dist_mat.shape[1], p=dist_mat, ndim=1)
f_mat = theano.function([dist_mat], [i_mat])
_dist_mat = np.asarray([[0.5, 0.1, 0.4], [0.1, 0.1, 0.8]])
print f_mat(_dist_mat.astype(np.float32)) #THIS DOES NOT WORK
我做错了什么?我希望f_mat
生成一个与_dist_mat
矩阵中的行数具有相同维度的向量。