如何使用tf.constant或numpy数组初始化tf.Variable?

时间:2019-04-29 19:16:20

标签: python numpy tensorflow global-variables

我正在尝试在tf.Variable()中初始化tf.InteractiveSession()。我已经有一些经过预先训练的权重,它们是单独的numpy文件。如何使用这些numpy值有效地初始化变量?

我经历了以下选择:

  1. 使用tf.assign()
  2. 在创建sess.run()时直接使用tf.Variable()

似乎值未正确初始化。 以下是我尝试过的一些代码。让我知道哪个是正确的?

def read_numpy(file):
    return np.fromfile(file,dtype='f')

def build_network():
    with tf.get_default_graph().as_default():
        x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
        sess = tf.get_default_session()
        with sess.as_default():
            sess.run(tf.global_variables_initializer())

sess = tf.InteractiveSession()
with sess.as_default():
    build_network()

这是正确的方法吗?我已经打印了session对象,它与整个会话相同。

edit:当前似乎使用sess.run(tf.global_variables_initializer())正在调用随机初始化操作

1 个答案:

答案 0 :(得分:1)

tf.Variable()接受numpy数组作为初始值:

import tensorflow as tf
import numpy as np

init = np.ones((2, 2))
x = tf.Variable(init) # <-- set initial value to assign to a variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]

因此只需使用numpy数组进行初始化,无需先将其转换为张量。

或者,您也可以使用tf.Variable.load()将numpy数组中的值分配给会话上下文中的变量:

import tensorflow as tf
import numpy as np

x = tf.Variable(tf.zeros((2, 2)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init = np.ones((2, 2))
    x.load(init)
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]