Tensorflow和Numpy不匹配数据格式

时间:2017-06-29 12:32:53

标签: python tensorflow deep-learning tensor

使用此代码:

import tensorflow as tf  

w = tf.Variable(tf.random_normal( [ 3 , 3 , 1 , 1  ], stddev = 0.01 )) 
if __name__ == '__main__': 
    initVar = tf.global_variables_initializer()  
    with tf.Session() as sess:  
        sess.run(initVar) 
        print w.eval()

由于数据格式w = tf.Variable(tf.random_normal( [kernel_height, kernel_width, input_channel, output_chhannel], stddev = 0.01 )),我希望看到这样的矩阵:

[[[[ -0.004  0.003  0.006]
   [ -0.005 -0.008  0.001]
   [  0.006  0.007  0.002]]]]

但打印出来:

[[[[ 0.001]] 
  [[-0.031]] 
  [[-0.005]]]

 [[[ 0.006]] 
  [[ 0.011]] 
  [[ 0.006]]]

 [[[ 0.008]] 
  [[-0.001]] 
  [[ 0.014]]]]

我想要的是将我的体重张量值与0和1的常数张量逐个相乘,以得到掩盖的权重,如:

w = [[[[ -0.004  0.003  0.006]
       [ -0.005 -0.008  0.001]
       [  0.006  0.007  0.002]]]]

mask = [[[[ 1  1  1]
          [ 1  1  0]
          [ 0  0  0]]]]

w * mask =  [[[[ -0.004  0.003  0.006]
               [ -0.005 -0.008  0.   ]
               [  0.     0.     0.   ]]]]

我用它的代码:

    mask = np.ones((3, 3, 1, 1), dtype=np.float32)
    mask[1, 2, :, :] = 0.
    mask[2, :, :, :] = 0. 

    weight = tf.get_variable("weight", [3, 3, 1, 1], tf.float32, tf.contrib.layers.xavier_initializer()) 

    weight *= tf.constant(mask, dtype=tf.float32)

但似乎它没有正常运作。感谢您的帮助。

1 个答案:

答案 0 :(得分:2)

你需要

w = tf.Variable(tf.random_normal([1, 1, 3, 3], stddev=0.01)) 

最后,您可以使用

import tensorflow as tf 
import numpy as np

mask = np.ones((1, 1, 3, 3), dtype=np.float32)
mask[:, :, 1, 2] = 0.
mask[:, :, 2, :] = 0. 

print(mask)

weight = tf.get_variable("weight", [3, 3, 1, 1], tf.float32, tf.contrib.layers.xavier_initializer())
weight *= tf.transpose( tf.constant(mask, dtype=tf.float32) )

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(tf.transpose(weight).eval())

你会得到

[[[[ 1.  1.  1.]
   [ 1.  1.  0.]
   [ 0.  0.  0.]]]]

[[[[ 0.88993669  0.80872607  0.57259583]
   [ 0.5067296  -0.20804334 -0.        ]
   [ 0.          0.          0.        ]]]]