从张量采样取决于张量流中的随机变量

时间:2018-07-29 17:29:42

标签: tensorflow probability sampling

是否可以从依赖于张量流中随机变量的张量中获取样本?我需要获得近似的样本分布,以用于损失函数进行优化。具体来说,在下面的示例中,我希望能够获取Y_output的样本,以便能够计算输出分布的均值和方差,并在损失函数中使用这些参数。

def sample_weight(mean, phi, seed=1):
    P_epsilon = tf.distributions.Normal(loc=0., scale=1.0)
    epsilon_s = P_epsilon.sample([1])
    s = tf.multiply(epsilon_s, tf.log(1.0+tf.exp(phi)))
    weight_sample = mean + s
    return weight_sample

X = tf.placeholder(tf.float32, shape=[None, 1], name="X")
Y_labels = tf.placeholder(tf.float32, shape=[None, 1], name="Y_labels")
sw0 = sample_weight(u0,p0)
sw1 = sample_weight(u1,p1)
Y_output = sw0 + tf.multiply(sw1,X)

loss = tf.losses.mean_squared_error(labels=Y_labels, predictions=Y_output)
train_op = tf.train.AdamOptimizer(0.5e-1).minimize(loss)
init_op = tf.global_variables_initializer()
losses = []
predictions = []

Fx = lambda x: 0.5*x + 5.0
xrnge = 50
xs, ys = build_toy_data(funcx=Fx, stdev=2.0, num=xrnge)

with tf.Session() as sess:
    sess.run(init_op)
    iterations=1000
    for i in range(iterations):
        stat = sess.run(loss, feed_dict={X: xs, Y_labels: ys})

1 个答案:

答案 0 :(得分:0)

不确定是否能回答您的问题,但是:当您在采样Tensor的下游有Op时(例如,通过调用Op创建的P_epsilon.sample([1]) ,只要您在下游sess.run上调用Tensor,示例操作就会重新运行,并产生一个新的随机值。例如:

import tensorflow as tf
from tensorflow_probability import distributions as tfd

n = tfd.Normal(0., 1.)
s = n.sample()
y = s**2
sess = tf.Session()  # Don't actually do this -- use context manager
print(sess.run(y))
# ==> 0.13539088
print(sess.run(y))
# ==> 0.15465781
print(sess.run(y))
# ==> 4.7929106

如果您想要一堆y的示例,则可以

import tensorflow as tf
from tensorflow_probability import distributions as tfd

n = tfd.Normal(0., 1.)
s = n.sample(100)
y = s**2
sess = tf.Session()  # Don't actually do this -- use context manager
print(sess.run(y))
# ==> vector of 100 squared random normal values

我们在tensorflow_probability中也提供了一些很酷的工具来完成您在此处所做的事情。即Bijector API和trainable_distributions API。

(另一个次要点:我建议使用tf.nn.softplus,或至少使用tf.log1p(tf.exp(x))代替tf.log(1.0 + tf.exp(x))。由于浮点不精确,后者的数值特性较差,前者经过优化)。

希望这会有所帮助!