我正试图在爱德华编码一个简单的高斯混合物。我按照基础步骤和教程。代码似乎没问题,但是当我运行它时,我收到以下错误:
Received a label value of 1 which is outside the valid range of [0, 2). Label values: 2 2 2 2...
这是我正在使用的代码。这很简单:
def build_toy_dataset(N):
pi = np.array([0.4, 0.6])
mus = [1, -1]
stds = [0.1, 0.1]
x = np.zeros(N, dtype=np.float32)
for n in range(N):
k = np.argmax(np.random.multinomial(1, pi))
x[n] = np.random.normal(mus[k], stds[k])
return x
N = 500 # number of data points
K = 2 # number of components
D = 1 # dimensionality of data
x_train = build_toy_dataset(N)
pi = Dirichlet(np.zeros(K, dtype=np.float32))
mu = Normal(loc = 0.0, scale = 1.0, sample_shape = K)
sigma = InverseGamma(1.0, 2.0, sample_shape=K)
x = ParamMixture(pi, {'loc': mu, 'scale': tf.sqrt(sigma)},
Normal,
sample_shape=N)
z = x.cat
qpi = Dirichlet(tf.Variable(tf.ones(K) / K))
qmu = Normal(loc = tf.Variable(0.0), scale = tf.nn.softplus(tf.Variable(0.0)), sample_shape = K)
qsigma = InverseGamma(tf.Variable(2.0), tf.Variable(2.0), sample_shape = K)
qz = Categorical(logits = tf.Variable(tf.zeros([N,K])))
inference = ed.ScoreKLqp({pi: qpi, mu: qmu, sigma: qsigma, z: qz},
data={x: x_train})
inference.initialize()
sess = ed.get_session()
tf.global_variables_initializer().run()
for _ in range(inference.n_iter):
info_dict = inference.update()
inference.print_progress(info_dict)