tf.reduce_mean

Time: 2019-06-23 17:52:19

Tags: python tensorflow precision

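On both CPU and GPU, tf.reduce_mean is numerically less stable than np.mean.

Does tf.reduce_mean run into numerical problems whenever the sum exceeds the range of the floating-point type?

Is there a better way to compute the mean of a float16 array in TensorFlow?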
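To make the "limit of the floating-point type" concrete, here is a minimal NumPy-only sketch. It assumes seeded random data (my choice, not part of the original test) and only shows that the total of a million values near 0.5 cannot be represented in float16, whose largest finite value is 65504.

```python
import numpy as np

# Largest finite value representable in float16.
f16_max = float(np.finfo(np.float16).max)

# Seeded data so the run is repeatable (an assumption of this sketch).
rng = np.random.default_rng(0)
x = rng.random(10**6).astype(np.float16)  # true sum is close to 5e5

with np.errstate(over="ignore"):          # silence the expected overflow warning
    sum16 = np.sum(x, dtype=np.float16)   # sum reduced in float16
sum32 = np.sum(x, dtype=np.float32)       # same data, float32 result

print("float16 max:", f16_max)
print("sum as float16:", sum16)
print("sum as float32:", sum32)
```

The float16 sum overflows because the total is far above 65504, even though the mean itself (about 0.5) is well within range.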

Results (CPU, TF 1.13.1, Linux):

np.mean 64: 0.499978537075602
np.sum  64: 499978.53707560204
np.mean 16: 0.5
np.sum  16: inf
tf.reduce_mean 16: nan

Results (GPU, compute capability 5.2, TF 1.13.1, CUDA 10.1, Linux):

np.mean 64: 0.500100701606694
np.sum  64: 500100.7016066939
np.mean 16: 0.5
np.sum  16: inf
tf.reduce_mean 16: nan

Results (GPU, compute capability 7.0, TF 1.13.1, CUDA 9.0, Linux):

np.mean 64: 0.4996047117607758
np.sum  64: 499604.7117607758
np.mean 16: 0.4995
np.sum  16: inf
tf.reduce_mean 16: nan

Test:

"""
Test numerical stability of reduce_mean
"""

import numpy as np
import tensorflow as tf


N = int(1e6)
dtype = np.float16

# Reference values computed in float64.
x = np.random.random(size=N)

print("np.mean 64:", np.mean(x))
print("np.sum  64:", np.sum(x))

# Same data converted to float16.
x = x.astype(dtype)
print("np.mean 16:", np.mean(x))
print("np.sum  16:", np.sum(x))

with tf.Session() as sess:
    x_tf = tf.constant(x, dtype=dtype)
    print("tf.reduce_mean 16:",
          sess.run(tf.reduce_mean(x_tf)))

3 answers:

Answer 0 (score: 2)

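From the numpy documentation:

"By default, float16 results are computed using float32 intermediates for extra precision."

From the tensorflow documentation:

"Please note that np.mean has a dtype parameter that could be used to specify the output type. By default this is dtype=float64. On the other hand, tf.reduce_mean has an aggressive type inference from input_tensor..."

So there is probably no better way than sess.run(tf.reduce_mean(tf.cast(x, np.float32))).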
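The effect of upcasting before the reduction can be illustrated without TensorFlow. The sketch below is a NumPy stand-in, not the TensorFlow call itself: it assumes seeded random data and compares NumPy's default float16 mean with an explicit float32 upcast, which plays the role of tf.cast(x, np.float32) in the expression above.

```python
import numpy as np

# Seeded data so the run is repeatable (an assumption of this sketch).
rng = np.random.default_rng(0)
x16 = rng.random(10**6).astype(np.float16)

# Default behaviour: float32 intermediates, result returned as float16.
m_default = np.mean(x16)

# Explicit upcast before reducing (NumPy analogue of tf.cast(x, np.float32)).
m_upcast = np.mean(x16.astype(np.float32))

# The same request expressed through np.mean's dtype parameter.
m_param = np.mean(x16, dtype=np.float32)

print("default:", m_default.dtype, m_default)
print("upcast: ", m_upcast.dtype, m_upcast)
print("dtype=: ", m_param.dtype, m_param)
```

If a float16 result is needed afterwards, the float32 mean can be cast back down, since the mean (unlike the sum) fits comfortably in float16.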

Answer 1 (score: 1)

I tried this and got a result of 0.5. I am using TensorFlow 1.13.1 with a GPU.

import numpy as np
import tensorflow as tf

x = np.random.random(size=10**8).astype(np.float16)
px = tf.placeholder(dtype=tf.float16, shape=(None,), name="x")

with tf.Session() as sess:
    print(sess.run(tf.reduce_mean(px), feed_dict={px: x}))

Answer 2 (score: 0)

Results:

tf.reduce_mean, tree reduction: 0.5

Test:

import numpy as np
import tensorflow as tf


N = int(1e8)
x = np.random.random(size=N).astype(np.float16)

with tf.Session() as sess:
    a = tf.reshape(x, (100, 100, 100, 100))
    for i in range(4):
        a = tf.reduce_mean(a, axis=-1)
    print("tf.reduce_mean, tree reduction:", sess.run(a))
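The reshape-and-reduce loop above averages 100 values at a time, so no single reduction has to hold a sum larger than about 100. The following is a NumPy sketch of the same staged idea, assuming seeded random data and a smaller array of 10**6 elements (both my choices); it keeps every intermediate in float16 to show that the staged mean stays finite where a direct float16 sum does not.

```python
import numpy as np

# Seeded data so the run is repeatable (an assumption of this sketch).
rng = np.random.default_rng(0)
x = rng.random(10**6).astype(np.float16)

# Direct float16 sum: the total (about 5e5) does not fit in float16.
with np.errstate(over="ignore"):
    direct_sum = np.sum(x, dtype=np.float16)

# Staged reduction: average 100 values at a time along the last axis.
a = x.reshape(100, 100, 100)
for _ in range(3):
    # Each partial sum covers 100 values in [0, 1], so it is at most 100.
    a = a.sum(axis=-1, dtype=np.float16) / np.float16(100)
staged_mean = a

print("direct float16 sum:", direct_sum)
print("staged float16 mean:", staged_mean)
```

This only works as written when the array length factors into equal-sized stages; other lengths would need padding or a weighted combination of the partial means.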