如何在TensorFlow中使用tf.get_variable和numpy值初始化变量?

时间:2016-07-04 18:19:10

标签: python numpy tensorflow

我想用numpy值初始化网络上的一些变量。为了这个例子考虑:

init=np.random.rand(1,2)
tf.get_variable('var_name',initializer=init)

当我这样做时,我收到错误:

ValueError: Shape of a new variable (var_name) must be fully defined, but instead was <unknown>.

为什么我收到了这个错误?

为了尝试修复它,我尝试了:

tf.get_variable('var_name',initializer=init, shape=[1,2])

产生了更奇怪的错误:

TypeError: 'numpy.ndarray' object is not callable

我尝试阅读the docs and examples,但它并没有真正帮助。

是否无法使用TensorFlow中的get_variable方法使用numpy数组初始化变量?

3 个答案:

答案 0 :(得分:36)

以下作品:

init = tf.constant(np.random.rand(1, 2))
tf.get_variable('var_name', initializer=init)

get_variable的文档确实有点缺乏。仅供您参考,initializer参数必须是TensorFlow Tensor对象(可以通过在您的案例中tf.constant值上调用numpy来构建),或者一个'callable',它接受两个参数shapedtype,它应该返回的值的形状和数据类型。同样,在您的情况下,如果您想使用“可调用”机制,可以编写以下内容:

init = lambda shape, dtype: np.random.rand(*shape)
tf.tf.get_variable('var_name', initializer=init, shape=[1, 2])

答案 1 :(得分:10)

@keveman回答得很好,作为补充,有 tf.get_variable(&#39; var_name&#39;,initializer = init)的用法,张量流文档确实提供了一个全面的例子

import numpy as np
import tensorflow as tf

value = [0, 1, 2, 3, 4, 5, 6, 7]
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
tf.reset_default_graph()
with tf.Session() :
    x = tf.get_variable('x', shape = [2, 4], initializer = init)
    x.initializer.run()
    print(x.eval())

    fitting shape :
[[0.  1.  2.  3.]
[4.  5.  6.  7.]]

print('larger shape:')
tf.reset_default_graph()
with tf.Session() :
    x = tf.get_variable('x', shape = [3, 4], initializer = init)
    x.initializer.run()
    print(x.eval())

    larger shape :
[[0.  1.  2.  3.]
[4.  5.  6.  7.]
[7.  7.  7.  7.]]

print('smaller shape:')
tf.reset_default_graph()
with tf.Session() :
    x = tf.get_variable('x', shape = [2, 3], initializer = init)

    * <b>`ValueError`< / b > : Too many elements provided.Needed at most 6, but received 8

https://www.tensorflow.org/api_docs/python/tf/constant_initializer

答案 2 :(得分:4)

如果已经创建了变量(即来自某个复杂函数),只需使用load

https://www.tensorflow.org/api_docs/python/tf/Variable#load

x_var = tf.Variable(tf.zeros((1, 2), tf.float32))
x_val = np.random.rand(1,2).astype(np.float32)

sess = tf.Session()
x_var.load(x_val, session=sess)

# test
assert np.all(sess.run(x_var) == x_val)