复制Tensorflow正态分布对象

时间:2018-02-02 16:53:05

标签: random tensorflow tensorboard normal-distribution

给出以下代码:

import tensorflow as tf

normal_dist = tf.contrib.distributions.Normal(.5, 1.3)
foo = normal_dist.sample()
bar = normal_dist.sample()
baz = foo + bar

sess = tf.Session()
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter("./logs", graph=tf.get_default_graph())

正常分布对象被复制两次,总共三个对象是坏的,因为你可以看到它们都是相同的分布(相同的意思,标准)。

有没有办法不复制?还是一种优化方式?寻找最佳实践。

three normal distributions

1 个答案:

答案 0 :(得分:0)

sample()方法创建一个 new 张量,它从分布中接收随机值。在幕后,Normal使用tf.random_normal op,它本身也会在调用时在图表中创建一个新节点。

如果不想每次都创建新的操作,您可以简单地多次评估相同的随机张量:

...
with tf.Session() as sess:
  print(sess.run(foo))
  print(sess.run(foo))
  print(sess.run(foo))

...每次都会输出不同的随机值。

顺便提一下,请注意张量板图片上的Normal_1Normal_2不是对象,但命名范围包含用于计算值的操作(您可以展开并放大以查看)。底部Normal也是一个范围,其中包含foobar的一些常见张量。