tf.nn.moments如何计算方差?

时间:2017-08-04 10:27:00

标签: tensorflow

查看测试示例:

import tensorflow as tf

x = tf.constant([[1,2],[3,4],[5,6]])
mean, variance = tf.nn.moments(x, [0])
with tf.Session() as sess:
    m, v = sess.run([mean, variance])
    print(m, v)

输出结果为:

[3 4]

[2 2]

我们想要计算沿轴0的方差,第一列是[1,3,5],平均值=(1 + 3 + 5)/ 3 = 3,它是正确的,方差= [(1 -3)^ 2 +(3-3)^ 2 +(5-3)^ 2] /3=2.6666,但输出为2,谁可以告诉我tf.nn.moments如何计算方差?

顺便说一下,查看API DOCshift做了什么?

2 个答案:

答案 0 :(得分:3)

问题是x是一个整数张量而TensorFlow不是强制转换,而是在不改变类型的情况下尽可能好地执行计算(因此输出也是整数)。您可以在x的构造中传递浮点数或指定tf.constantdtype参数:

x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)

然后你得到预期的结果:

import tensorflow as tf

x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)
mean, variance = tf.nn.moments(x, [0])
with tf.Session() as sess:
    m, v = sess.run([mean, variance])
    print(m, v)
>>> [ 3.  4.] [ 2.66666675  2.66666675]

关于shift参数,它似乎允许您指定一个值,以及" shift"输入。通过移位,它们意味着减去,因此如果您的输入为[1., 2., 4.]并且您提供了shift,例如2.5,TensorFlow将首先减去该数量并计算来自{{1}的时刻}。一般来说,将其保留为[-1.5, 0.5, 1.5]似乎是安全的,它将通过输入的平均值执行移位,但我想可能存在给出预定移位值的情况(例如,如果您知道或拥有输入平均值的近似概念可以产生更好的数值稳定性。

答案 1 :(得分:0)

# Replace the following line with correct data dtype
x = tf.constant([[1,2],[3,4],[5,6]])
# suppose you don't want tensorflow to trim the decimal then use float data type. 
x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)

Results: array([ 2.66666675,  2.66666675], dtype=float32)

注意:从原始实现shift is not used

开始