Tensorflow的占位符初始化不同于tensorflow的常量初始化。为什么?

时间:2019-02-25 11:56:20

标签: python tensorflow initialization

我编写了2个函数,它们以不同的方式初始化 tensorflow 的变量。我不知道结果为何不同。这是使用占位符进行初始化的第一个函数:

第一个功能

import tensorflow as tf
import numpy as np

def linear_function():
    np.random.seed(1)

    X = tf.placeholder(dtype = tf.float64, name='X')
    W = tf.placeholder(dtype = tf.float64, name='W')
    b = tf.placeholder(dtype = tf.float64, name='b')
    Y = tf.add(tf.matmul(W, X), b)

    sess = tf.Session()

    result = sess.run(Y, feed_dict={W:np.random.randn(4,3), X:np.random.randn(3,1), b:np.random.randn(4,1)})
    sess.close()
    return result
print( "result = " + str(linear_function()))

结果是:

result = [[-1.98748544]
 [-2.76826248]
 [-0.78635415]
 [-2.77463846]]

第二个功能

第二个函数使用tf.constant初始化变量:

def linear_function():

    np.random.seed(1)

    X = tf.constant(np.random.randn(3,1), name ="X")
    W = tf.constant(np.random.randn(4,3), name ="X")
    b = tf.constant(np.random.randn(4,1), name ="X")
    Y = tf.add(tf.matmul(W,X), b)

    sess = tf.Session()
    result = sess.run(Y)

    sess.close()

    return result

print( "result = " + str(linear_function()))

结果:

result = [[-2.15657382]
 [ 2.95891446]
 [-1.08926781]
 [-0.84538042]]

出什么问题了?与np.random.seed(1)有关吗?​​

谢谢。

1 个答案:

答案 0 :(得分:1)

在第一个代码段中,feed_dict是:

{W:np.random.randn(4,3), X:np.random.randn(3,1), b:np.random.randn(4,1)}

因此,首先为W生成一个随机值,然后为Xb生成一个随机值。但是,在第二个片段中,随机值的顺序为XWb。由于生成随机数的顺序不同,因此值不同。例如,如果您在第一个代码段的feed_dict中更改了顺序,您将获得与第二个代码相同的结果:

import tensorflow as tf
import numpy as np

def linear_function():
    np.random.seed(1)

    X = tf.placeholder(dtype = tf.float64, name='X')
    W = tf.placeholder(dtype = tf.float64, name='W')
    b = tf.placeholder(dtype = tf.float64, name='b')
    Y = tf.add(tf.matmul(W, X), b)

    sess = tf.Session()

    result = sess.run(Y, feed_dict={X:np.random.randn(3,1), W:np.random.randn(4,3), b:np.random.randn(4,1)})
    sess.close()
    return result

print( "result = " + str(linear_function()))

输出:

result = [[-2.15657382]
 [ 2.95891446]
 [-1.08926781]
 [-0.84538042]]