我正在尝试在tf.Variable()
中初始化tf.InteractiveSession()
。我已经有一些经过预先训练的权重,它们是单独的numpy
文件。如何使用这些numpy
值有效地初始化变量?
我经历了以下选择:
tf.assign()
sess.run()
时直接使用tf.Variable()
似乎值未正确初始化。 以下是我尝试过的一些代码。让我知道哪个是正确的?
def read_numpy(file):
return np.fromfile(file,dtype='f')
def build_network():
with tf.get_default_graph().as_default():
x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
sess = tf.get_default_session()
with sess.as_default():
sess.run(tf.global_variables_initializer())
sess = tf.InteractiveSession()
with sess.as_default():
build_network()
这是正确的方法吗?我已经打印了session
对象,它与整个会话相同。
edit:当前似乎使用sess.run(tf.global_variables_initializer())
正在调用随机初始化操作
答案 0 :(得分:1)
tf.Variable()
接受numpy数组作为初始值:
import tensorflow as tf
import numpy as np
init = np.ones((2, 2))
x = tf.Variable(init) # <-- set initial value to assign to a variable
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
print(x.eval())
# [[1. 1.]
# [1. 1.]]
因此只需使用numpy数组进行初始化,无需先将其转换为张量。
或者,您也可以使用tf.Variable.load()
将numpy数组中的值分配给会话上下文中的变量:
import tensorflow as tf
import numpy as np
x = tf.Variable(tf.zeros((2, 2)))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
init = np.ones((2, 2))
x.load(init)
print(x.eval())
# [[1. 1.]
# [1. 1.]]