如何在tensorflow中反转tf.image.per_image_standardization()函数?

时间:2018-04-25 18:00:19

标签: tensorflow deep-learning normalization

Tensorflow中的

tf.image.per_image_standardization()将每个图像转换为零均值&单位方差。因此,在训练深度学习模型时,这会导致非爆炸渐变。但是当我们想要显示图像数组时,我们如何在Tensorflow中恢复这个z-score标准化步骤?

2 个答案:

答案 0 :(得分:1)

tf.image.per_image_standardization()图层将创建一些可用于恢复原始数据的内部变量。请注意,这是无证件行为,不保证保持不变。目前,您仍然可以使用下面的代码(已测试)作为参考,以获取相关的张量并恢复原始数据:

import tensorflow as tf
import numpy as np

img_size = 3
a = tf.placeholder( shape = ( img_size, img_size, 1 ), dtype = tf.float32 )
b = tf.image.per_image_standardization( a )

with tf.Session() as sess:
    tensors, tensor_names = [], []
    for l in sess.graph.get_operations():
        tensors.append( sess.graph.get_tensor_by_name( l.name + ":0" ) )
        tensor_names.append( l.name )

    #mean_t = sess.graph.get_tensor_by_name( "per_image_standardization/Mean:0" )
    #variance_t = sess.graph.get_tensor_by_name( "per_image_standardization/Sqrt:0" )

    foobar = np.reshape( np.array( range( img_size * img_size ), dtype = np.float32 ), ( img_size, img_size, 1 ) )
    res =  sess.run( tensors, feed_dict = { a : foobar } )
    #for i in xrange( len( res ) ):
    #    print( i, tensor_names[ i ] + ":" )
    #    print( res[ i ] )
    #    print()

    mean = res[ 6 ] # "per_image_standardization/Mean:0"
    variance = res[ 13 ] # "per_image_standardization/Sqrt:0"
    standardized = res[ 18 ] # "per_image_standardization:0"
    original = standardized * variance + mean
    print( original )

您可以取消注释mean_tvariance_t行,以便按名称获取相关张量的引用。 (需要对sess.run()部分进行一些重写。)您可以取消注释以for i in xrange(...开头的四行(不需要重写)来打印所有可用的创建张量以进行启发。 :)

上述代码按原样输出:

  

[[[0]
    [1.]
    [2.]]

     

[[3]
    [4.]
    [5.]]

     

[[6]
    [7.]
    [8.]]]

这正是馈送到网络的数据。

答案 1 :(得分:0)

通过“显示图像阵列”我假设你的意思是在张量板中显示它。如果是这种情况,那么您不需要做任何事情,张量板可以处理已经标准化的图像。如果您希望将原始值用于任何其他目的,为什么不在标准化之前使用该变量,例如:

img = tf.placeholder(...)
img_std = tf.image.per_image_standardization(img)

您可以以任何您认为合适的方式使用imgimg_std

如果你以某种方式有一个用于非标准化图像的用例,而上面未涉及的标准化图像那么你需要自己计算平均值和标准差,然后乘以标准差并加上平均值。请注意,tf.image.per_image_standardization使用文档中定义的adjusted_stddev

adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))